multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 1h2m3s
All checks were successful
Validate Operations / validate-operations (push) Successful in 1h2m3s
This commit is contained in:
@@ -182,7 +182,7 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
for (mlir::Value input : inputs)
|
for (mlir::Value input : inputs)
|
||||||
@@ -198,10 +198,10 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||||
@@ -219,7 +219,7 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
for (mlir::Value input : inputs)
|
for (mlir::Value input : inputs)
|
||||||
@@ -234,10 +234,10 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
if (coresCount != -1) {
|
if (coresCount != -1) {
|
||||||
int computeOpsCount = 0;
|
int computeOpsCount = 0;
|
||||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
||||||
if (isa<spatial::SpatWeightedCompute>(op))
|
if (isa<spatial::SpatCompute>(op))
|
||||||
computeOpsCount++;
|
computeOpsCount++;
|
||||||
|
|
||||||
if (computeOpsCount > coresCount) {
|
if (computeOpsCount > coresCount) {
|
||||||
@@ -167,16 +167,16 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
|
|||||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||||
Value source = funcSource(toRemoveOp);
|
Value source = funcSource(toRemoveOp);
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
|
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
||||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
mapper.map(source, BB->getArgument(0));
|
mapper.map(source, BB->getArgument(0));
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
auto newInst = rewriter.clone(*inst, mapper);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||||
inst->replaceAllUsesWith(newCompute);
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
inst->erase();
|
inst->erase();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -189,8 +189,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
auto sources = toRemoveOp.getInputs();
|
auto sources = toRemoveOp.getInputs();
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
if (llvm::any_of(
|
if (llvm::any_of(
|
||||||
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
|
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||||
SmallVector<Type> sourceTypes;
|
SmallVector<Type> sourceTypes;
|
||||||
SmallVector<Location> sourceLoc;
|
SmallVector<Location> sourceLoc;
|
||||||
for (auto source : sources) {
|
for (auto source : sources) {
|
||||||
@@ -204,8 +204,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||||
mapper.map(source, bbArg);
|
mapper.map(source, bbArg);
|
||||||
auto newConcat = rewriter.clone(*inst, mapper);
|
auto newConcat = rewriter.clone(*inst, mapper);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
|
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||||
inst->replaceAllUsesWith(newCompute);
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
inst->erase();
|
inst->erase();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -298,14 +298,15 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
|
SmallVector<spatial::SpatCompute> trivialComputes;
|
||||||
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
|
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
||||||
if (compute->hasOneUse()) {
|
if (compute->hasOneUse()) {
|
||||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
auto& use = *compute->getUses().begin();
|
||||||
|
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||||
|
|
||||||
if (user && user.getInputs().size() == 1)
|
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||||
trivialComputes.push_back(compute);
|
trivialComputes.push_back(compute);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,12 +318,15 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
trivialComputes.pop_back();
|
trivialComputes.pop_back();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
auto& computeUse = *compute->getUses().begin();
|
||||||
|
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
|
||||||
|
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||||
|
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
|
|
||||||
@@ -343,7 +347,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
||||||
newTerminator->erase();
|
newTerminator->erase();
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
for (auto& op : child.getBody().front()) {
|
for (auto& op : child.getBody().front()) {
|
||||||
@@ -371,14 +375,16 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
toErase.insert(compute);
|
toErase.insert(compute);
|
||||||
|
|
||||||
if (newCompute->hasOneUse()) {
|
if (newCompute->hasOneUse()) {
|
||||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
|
auto& use = *newCompute->getUses().begin();
|
||||||
if (user && user.getInputs().size() == 1)
|
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||||
|
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||||
trivialComputes.push_back(newCompute);
|
trivialComputes.push_back(newCompute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto compute : toErase) {
|
for (auto compute : toErase) {
|
||||||
compute.getResult(0).dropAllUses();
|
for (Value result : compute->getResults())
|
||||||
|
result.dropAllUses();
|
||||||
compute.erase();
|
compute.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -386,7 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
|
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
||||||
if (isAlwaysWeight)
|
if (isAlwaysWeight)
|
||||||
markWeightAlways(constantOp);
|
markWeightAlways(constantOp);
|
||||||
});
|
});
|
||||||
@@ -394,7 +400,7 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
|||||||
|
|
||||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
|
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||||
|
|
||||||
for (auto compute : computes) {
|
for (auto compute : computes) {
|
||||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||||
@@ -430,7 +436,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||||
auto* newBlock =
|
auto* newBlock =
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
|||||||
@@ -147,33 +147,37 @@ static Value buildPackedBias(bool hasBias,
|
|||||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createIm2colCompute(Value x,
|
static SmallVector<Value> createIm2colRowComputes(Value x,
|
||||||
RankedTensorType xType,
|
RankedTensorType xType,
|
||||||
RankedTensorType im2colType,
|
RankedTensorType im2colType,
|
||||||
RankedTensorType rowType,
|
RankedTensorType im2colRowType,
|
||||||
int64_t batchSize,
|
RankedTensorType gemmInputRowType,
|
||||||
int64_t numChannelsIn,
|
int64_t batchSize,
|
||||||
int64_t xHeight,
|
int64_t numChannelsIn,
|
||||||
int64_t xWidth,
|
int64_t xHeight,
|
||||||
int64_t wHeight,
|
int64_t xWidth,
|
||||||
int64_t wWidth,
|
int64_t wHeight,
|
||||||
int64_t padHeightBegin,
|
int64_t wWidth,
|
||||||
int64_t padHeightEnd,
|
int64_t padHeightBegin,
|
||||||
int64_t padWidthBegin,
|
int64_t padHeightEnd,
|
||||||
int64_t padWidthEnd,
|
int64_t padWidthBegin,
|
||||||
int64_t strideHeight,
|
int64_t padWidthEnd,
|
||||||
int64_t strideWidth,
|
int64_t strideHeight,
|
||||||
int64_t dilationHeight,
|
int64_t strideWidth,
|
||||||
int64_t dilationWidth,
|
int64_t dilationHeight,
|
||||||
int64_t outWidth,
|
int64_t dilationWidth,
|
||||||
int64_t patchSize,
|
int64_t outWidth,
|
||||||
int64_t numPatches,
|
int64_t patchSize,
|
||||||
int64_t numPatchesPerBatch,
|
int64_t numPatches,
|
||||||
ConversionPatternRewriter& rewriter,
|
int64_t numPatchesPerBatch,
|
||||||
Location loc) {
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
|
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
|
||||||
|
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
|
|
||||||
// Pad input with zeros if needed:
|
// Pad input with zeros if needed:
|
||||||
@@ -240,7 +244,7 @@ static Value createIm2colCompute(Value x,
|
|||||||
|
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rowType,
|
im2colRowType,
|
||||||
patch,
|
patch,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0},
|
||||||
@@ -256,121 +260,117 @@ static Value createIm2colCompute(Value x,
|
|||||||
|
|
||||||
rewriter.setInsertionPointAfter(im2colLoop);
|
rewriter.setInsertionPointAfter(im2colLoop);
|
||||||
Value im2col = im2colLoop.getResult(0);
|
Value im2col = im2colLoop.getResult(0);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
|
||||||
});
|
|
||||||
return im2colComputeOp.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value createPackedIm2colRows(Value im2col,
|
Value gemmInputRows = im2col;
|
||||||
RankedTensorType im2colType,
|
if (packFactor != 1) {
|
||||||
Type elemType,
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
int64_t numPatches,
|
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||||
int64_t patchSize,
|
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||||
int64_t packFactor,
|
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||||
ConversionPatternRewriter& rewriter,
|
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||||
Location loc) {
|
loc,
|
||||||
if (packFactor == 1)
|
groupedType,
|
||||||
return im2col;
|
paddedIm2col,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
{0, 1},
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
{2}
|
||||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
});
|
||||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||||
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
|
loc,
|
||||||
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
|
packedType,
|
||||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
groupedIm2col,
|
||||||
loc,
|
SmallVector<ReassociationIndices> {
|
||||||
groupedType,
|
{0},
|
||||||
paddedIm2col,
|
{1, 2}
|
||||||
SmallVector<ReassociationIndices> {
|
});
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
packedType,
|
|
||||||
groupedIm2col,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
|
|
||||||
});
|
|
||||||
return packedComputeOp.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value createUnpackedOutput(Value packedOutput,
|
|
||||||
RankedTensorType gemmOutType,
|
|
||||||
RankedTensorType outType,
|
|
||||||
int64_t numPatches,
|
|
||||||
int64_t numChannelsOut,
|
|
||||||
int64_t packFactor,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
if (packFactor == 1)
|
|
||||||
return packedOutput;
|
|
||||||
|
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
|
||||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
|
||||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
|
||||||
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
|
|
||||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
expandedType,
|
|
||||||
packedOutputArg,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
paddedType,
|
|
||||||
expandedOutput,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
|
|
||||||
Value unpackedOutput = paddedOutput;
|
|
||||||
if (paddedNumPatches != numPatches) {
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
unpackedOutput =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
|
SmallVector<Value> rowResults;
|
||||||
|
rowResults.reserve(packedNumRows);
|
||||||
|
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(packFactor * patchSize)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
rowResults.push_back(
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
|
||||||
});
|
});
|
||||||
return unpackComputeOp.getResult(0);
|
|
||||||
|
SmallVector<Value> rows;
|
||||||
|
rows.reserve(im2colComputeOp.getNumResults());
|
||||||
|
for (Value result : im2colComputeOp.getResults())
|
||||||
|
rows.push_back(result);
|
||||||
|
return rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createCollectedConvOutput(Value gemmOut,
|
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||||
Type convType,
|
Type convType,
|
||||||
|
RankedTensorType gemmOutType,
|
||||||
RankedTensorType nhwcType,
|
RankedTensorType nhwcType,
|
||||||
RankedTensorType outType,
|
RankedTensorType outType,
|
||||||
|
int64_t numPatches,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t packFactor,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto collectComputeOp =
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
Value gemmOutArg = gemmOutArgs.front();
|
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||||
|
Value gemmOut;
|
||||||
// Restore to NCHW layout:
|
if (packFactor == 1) {
|
||||||
// [numPatches, numChannelsOut]
|
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
|
||||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||||
// -> [1, numChannelsOut, outHeight, outWidth]
|
}
|
||||||
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
else {
|
||||||
loc,
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
nhwcType,
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
gemmOutArg,
|
Value packedOutput =
|
||||||
SmallVector<ReassociationIndices> {
|
gemmRowArgs.size() == 1
|
||||||
{0, 1, 2},
|
? gemmRowArgs.front()
|
||||||
{3}
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||||
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
expandedType,
|
||||||
|
packedOutput,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
});
|
});
|
||||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
loc,
|
||||||
|
paddedType,
|
||||||
|
expandedOutput,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
|
||||||
|
gemmOut = paddedOutput;
|
||||||
|
if (paddedNumPatches != numPatches) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore to NCHW layout:
|
||||||
|
// [numPatches, numChannelsOut]
|
||||||
|
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||||
|
// -> [1, numChannelsOut, outHeight, outWidth]
|
||||||
|
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
nhwcType,
|
||||||
|
gemmOut,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1, 2},
|
||||||
|
{3}
|
||||||
});
|
});
|
||||||
|
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||||
|
});
|
||||||
return collectComputeOp.getResult(0);
|
return collectComputeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -487,11 +487,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
|
|
||||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
Value biasMatrix;
|
Value biasMatrix;
|
||||||
DenseElementsAttr biasDenseAttr;
|
DenseElementsAttr biasDenseAttr;
|
||||||
if (hasB) {
|
if (hasB) {
|
||||||
gemmC = b;
|
gemmBias = b;
|
||||||
biasDenseAttr = getDenseConstantAttr(b);
|
biasDenseAttr = getDenseConstantAttr(b);
|
||||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||||
}
|
}
|
||||||
@@ -500,94 +500,86 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const int64_t effectiveMaxParallelPixels =
|
const int64_t effectiveMaxParallelPixels =
|
||||||
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
||||||
|
|
||||||
Value im2col = createIm2colCompute(x,
|
// Keep the standard im2col view of convolution:
|
||||||
xType,
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
im2colType,
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
rowType,
|
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
||||||
batchSize,
|
//
|
||||||
numChannelsIn,
|
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
|
||||||
xHeight,
|
// the row it needs instead of receiving a full packed tensor and slicing it locally.
|
||||||
xWidth,
|
auto gemmInputRowType =
|
||||||
wHeight,
|
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||||
wWidth,
|
auto gemmOutputRowType =
|
||||||
padHeightBegin,
|
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
padHeightEnd,
|
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
||||||
padWidthBegin,
|
xType,
|
||||||
padWidthEnd,
|
im2colType,
|
||||||
strideHeight,
|
rowType,
|
||||||
strideWidth,
|
gemmInputRowType,
|
||||||
dilationHeight,
|
batchSize,
|
||||||
dilationWidth,
|
numChannelsIn,
|
||||||
outWidth,
|
xHeight,
|
||||||
patchSize,
|
xWidth,
|
||||||
numPatches,
|
wHeight,
|
||||||
numPatchesPerBatch,
|
wWidth,
|
||||||
rewriter,
|
padHeightBegin,
|
||||||
loc);
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
outWidth,
|
||||||
|
patchSize,
|
||||||
|
numPatches,
|
||||||
|
numPatchesPerBatch,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
|
||||||
Value gemmOut;
|
Value gemmB = buildPackedWeight(wDenseAttr,
|
||||||
if (effectiveMaxParallelPixels == 1) {
|
wTrans,
|
||||||
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
|
wType,
|
||||||
gemmOut = ONNXGemmOp::create(rewriter,
|
numChannelsIn,
|
||||||
loc,
|
numChannelsOut,
|
||||||
gemmOutType,
|
wHeight,
|
||||||
im2col,
|
wWidth,
|
||||||
wTrans,
|
patchSize,
|
||||||
gemmC,
|
effectiveMaxParallelPixels,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
loc);
|
||||||
rewriter.getBoolAttr(false),
|
Value gemmC = buildPackedBias(
|
||||||
rewriter.getBoolAttr(false))
|
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
.getY();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Keep the standard im2col view of convolution:
|
|
||||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
|
||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
|
||||||
// but repack several old rows into one new row so we use the available crossbar size better.
|
|
||||||
//
|
|
||||||
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
|
|
||||||
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
|
||||||
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
|
||||||
// A_packed: [ceil(numPatches / N), N * patchSize]
|
|
||||||
// B_packed: [N * patchSize, N * cOut]
|
|
||||||
// Y_packed: [ceil(numPatches / N), N * cOut]
|
|
||||||
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
|
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
|
||||||
auto packedOutType =
|
|
||||||
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
|
||||||
|
|
||||||
Value packedA = createPackedIm2colRows(
|
SmallVector<Value> gemmRows;
|
||||||
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
|
gemmRows.reserve(gemmInputRows.size());
|
||||||
Value packedB = buildPackedWeight(wDenseAttr,
|
for (Value gemmInputRow : gemmInputRows) {
|
||||||
wTrans,
|
Value gemmRow = ONNXGemmOp::create(rewriter,
|
||||||
wType,
|
loc,
|
||||||
numChannelsIn,
|
gemmOutputRowType,
|
||||||
numChannelsOut,
|
gemmInputRow,
|
||||||
wHeight,
|
gemmB,
|
||||||
wWidth,
|
gemmC,
|
||||||
patchSize,
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
effectiveMaxParallelPixels,
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter,
|
rewriter.getBoolAttr(false),
|
||||||
loc);
|
rewriter.getBoolAttr(false))
|
||||||
Value packedC = buildPackedBias(
|
.getY();
|
||||||
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
gemmRows.push_back(gemmRow);
|
||||||
Value packedOut = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
packedOutType,
|
|
||||||
packedA,
|
|
||||||
packedB,
|
|
||||||
packedC,
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getBoolAttr(false),
|
|
||||||
rewriter.getBoolAttr(false))
|
|
||||||
.getY();
|
|
||||||
gemmOut = createUnpackedOutput(
|
|
||||||
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
|
rewriter.replaceOp(convOp,
|
||||||
|
createCollectedConvOutput(gemmRows,
|
||||||
|
convOp.getType(),
|
||||||
|
gemmOutType,
|
||||||
|
nhwcType,
|
||||||
|
outType,
|
||||||
|
numPatches,
|
||||||
|
numChannelsOut,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,15 +42,15 @@ private:
|
|||||||
raw_ostream& os;
|
raw_ostream& os;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Draws the subgraph for a given spatial::SpatWeightedCompute, including:
|
* Draws the subgraph for a given spatial::SpatCompute, including:
|
||||||
* 1. Input nodes (block arguments)
|
* 1. Input nodes (block arguments)
|
||||||
* 2. Operations
|
* 2. Operations
|
||||||
* 3. Edges between yield (output) and its users
|
* 3. Edges between yield (output) and its users
|
||||||
*
|
*
|
||||||
* @param op The spatial::SpatWeightedCompute to draw the subgraph for.
|
* @param op The spatial::SpatCompute to draw the subgraph for.
|
||||||
* @param computeNum The number of the compute operation.
|
* @param computeNum The number of the compute operation.
|
||||||
*/
|
*/
|
||||||
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
|
void drawComputeOpSubgraph(spatial::SpatCompute op, size_t computeNum) {
|
||||||
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
|
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
|
||||||
<< "\t\tstyle=filled;\n"
|
<< "\t\tstyle=filled;\n"
|
||||||
<< "\t\tcolor=lightblue;\n";
|
<< "\t\tcolor=lightblue;\n";
|
||||||
@@ -217,7 +217,7 @@ void SpatialToGraphvizPass::runOnOperation() {
|
|||||||
// 1. Print their subgraph
|
// 1. Print their subgraph
|
||||||
// 2. Print the edges from its inputs to its outputs
|
// 2. Print the edges from its inputs to its outputs
|
||||||
for (Operation& op : func.getOps()) {
|
for (Operation& op : func.getOps()) {
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
drawComputeOpSubgraph(computeOp, computeNum++);
|
drawComputeOpSubgraph(computeOp, computeNum++);
|
||||||
}
|
}
|
||||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ private:
|
|||||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||||
void
|
void
|
||||||
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
|
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
|
||||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
Value channelSourceOp,
|
Value channelSourceOp,
|
||||||
Value consumerValue,
|
Value consumerValue,
|
||||||
@@ -73,7 +73,7 @@ private:
|
|||||||
void annotateChannelCoreIds(func::FuncOp funcOp);
|
void annotateChannelCoreIds(func::FuncOp funcOp);
|
||||||
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
|
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ static size_t countComputeLeafUsers(Value value) {
|
|||||||
auto walkUses = [&](Value currentValue, auto& self) -> void {
|
auto walkUses = [&](Value currentValue, auto& self) -> void {
|
||||||
for (OpOperand& use : currentValue.getUses()) {
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
Operation* owner = use.getOwner();
|
Operation* owner = use.getOwner();
|
||||||
if (isa<spatial::SpatWeightedCompute>(owner)) {
|
if (isa<spatial::SpatCompute>(owner)) {
|
||||||
leafUserCount++;
|
leafUserCount++;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -174,7 +174,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
markOpToRemove(receiveOp);
|
markOpToRemove(receiveOp);
|
||||||
runOnReceiveOp(receiveOp, rewriter);
|
runOnReceiveOp(receiveOp, rewriter);
|
||||||
}
|
}
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
runOnComputeOp(computeOp, rewriter);
|
runOnComputeOp(computeOp, rewriter);
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
dumpModule(moduleOp, "pim0");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
auto& block = computeOp.getRegion().front();
|
auto& block = computeOp.getRegion().front();
|
||||||
@@ -504,7 +504,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
|
|
||||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||||
for (auto& op : funcOp.getBody().getOps())
|
for (auto& op : funcOp.getBody().getOps())
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
unsigned numComputeWeights = computeOp.getWeights().size();
|
unsigned numComputeWeights = computeOp.getWeights().size();
|
||||||
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
|
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
|
||||||
TypedValue<TensorType> tensorSource;
|
TypedValue<TensorType> tensorSource;
|
||||||
@@ -513,7 +513,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||||
|
|
||||||
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||||
@@ -538,7 +538,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
|
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
|
||||||
|
|
||||||
// Compute results must be transferred through channels via send/receive
|
// Compute results must be transferred through channels via send/receive
|
||||||
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
|
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
|
||||||
@@ -553,7 +553,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
Value channelSourceOp,
|
Value channelSourceOp,
|
||||||
Value consumerValue,
|
Value consumerValue,
|
||||||
@@ -614,7 +614,7 @@ void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
|
|||||||
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
||||||
for (OpOperand& use : currentValue.getUses()) {
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
Operation* owner = use.getOwner();
|
Operation* owner = use.getOwner();
|
||||||
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
|
if (auto computeUser = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||||
replaceBlockArgumentWithRecvOp(
|
replaceBlockArgumentWithRecvOp(
|
||||||
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||||
let summary = "Compute region with attached constant weights";
|
let summary = "Compute region with attached constant weights";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
|
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
|
||||||
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp());
|
auto wcomputeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp());
|
||||||
if (wcomputeOp)
|
if (wcomputeOp)
|
||||||
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
|
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
|
||||||
|
|
||||||
@@ -134,7 +134,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigth
|
|||||||
LogicalResult SpatWeightedMVMOp::verify() {
|
LogicalResult SpatWeightedMVMOp::verify() {
|
||||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getInput().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
@@ -155,7 +155,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
|||||||
LogicalResult SpatWeightedVMMOp::verify() {
|
LogicalResult SpatWeightedVMMOp::verify() {
|
||||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getInput().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
@@ -200,9 +200,8 @@ LogicalResult SpatVMaxOp::verify() {
|
|||||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatWeightedCompute::verify() {
|
LogicalResult SpatCompute::verify() {
|
||||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
// Check that the terminator yields the same number and types as the compute results.
|
||||||
// operand with the same type as the result
|
|
||||||
auto& block = getBody().front();
|
auto& block = getBody().front();
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
@@ -257,7 +256,7 @@ LogicalResult SpatWeightedCompute::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
Block& block = getBody().front();
|
Block& block = getBody().front();
|
||||||
if (!llvm::hasSingleElement(block))
|
if (!llvm::hasSingleElement(block))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -74,15 +74,15 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
|||||||
return aggregatedEdges;
|
return aggregatedEdges;
|
||||||
}
|
}
|
||||||
|
|
||||||
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
|
||||||
llvm::ArrayRef<IndexedEdge> edges) {
|
llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
VirtualGraph graph;
|
VirtualGraph graph;
|
||||||
graph.nodes.reserve(spatWeightedComputes.size());
|
graph.nodes.reserve(spatComputes.size());
|
||||||
for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||||
VirtualNode node;
|
VirtualNode node;
|
||||||
node.originalComputeIndices.push_back(index);
|
node.originalComputeIndices.push_back(index);
|
||||||
node.weight = getSpatComputeWeight(spatWeightedCompute);
|
node.weight = getSpatComputeWeight(spatCompute);
|
||||||
node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute);
|
node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
|
||||||
graph.nodes.push_back(std::move(node));
|
graph.nodes.push_back(std::move(node));
|
||||||
}
|
}
|
||||||
graph.edges = aggregateEdges(edges);
|
graph.edges = aggregateEdges(edges);
|
||||||
@@ -344,22 +344,22 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::A
|
|||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
||||||
llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
llvm::ArrayRef<SpatCompute> spatComputes,
|
||||||
llvm::ArrayRef<IndexedEdge> originalEdges) {
|
llvm::ArrayRef<IndexedEdge> originalEdges) {
|
||||||
DCPAnalysisResult result;
|
DCPAnalysisResult result;
|
||||||
std::vector<size_t> originalToVirtualNode(spatWeightedComputes.size(), 0);
|
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
|
||||||
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
||||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
||||||
originalToVirtualNode[originalIndex] = virtualNodeIndex;
|
originalToVirtualNode[originalIndex] = virtualNodeIndex;
|
||||||
|
|
||||||
auto dominanceOrder = computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges);
|
auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
|
||||||
result.dominanceOrderCompute.reserve(dominanceOrder.size());
|
result.dominanceOrderCompute.reserve(dominanceOrder.size());
|
||||||
for (size_t originalIndex : dominanceOrder) {
|
for (size_t originalIndex : dominanceOrder) {
|
||||||
SpatWeightedCompute spatWeightedCompute = spatWeightedComputes[originalIndex];
|
SpatCompute spatCompute = spatComputes[originalIndex];
|
||||||
size_t cpu = originalToVirtualNode[originalIndex];
|
size_t cpu = originalToVirtualNode[originalIndex];
|
||||||
result.dominanceOrderCompute.push_back(spatWeightedCompute);
|
result.dominanceOrderCompute.push_back(spatCompute);
|
||||||
result.computeToCpuMap[spatWeightedCompute] = cpu;
|
result.computeToCpuMap[spatCompute] = cpu;
|
||||||
result.cpuToLastComputeMap[cpu] = spatWeightedCompute;
|
result.cpuToLastComputeMap[cpu] = spatCompute;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
|
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||||
@@ -367,10 +367,10 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
|
||||||
llvm::ArrayRef<IndexedEdge> edges,
|
llvm::ArrayRef<IndexedEdge> edges,
|
||||||
MLIRContext* context) {
|
MLIRContext* context) {
|
||||||
GraphDCP graphDCP(spatWeightedComputes, edges);
|
GraphDCP graphDCP(spatComputes, edges);
|
||||||
if (coresCount.getValue() > 0)
|
if (coresCount.getValue() > 0)
|
||||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||||
graphDCP.setContext(context);
|
graphDCP.setContext(context);
|
||||||
@@ -380,7 +380,7 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedC
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
@@ -388,39 +388,33 @@ SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
|||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
|
if (auto res = llvm::dyn_cast<SpatCompute>(op))
|
||||||
return res;
|
return res;
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult DCPAnalysis::run() {
|
DCPAnalysisResult DCPAnalysis::run() {
|
||||||
SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
SmallVector<SpatCompute, 10> spatComputes;
|
||||||
SmallVector<IndexedEdge, 10> edges;
|
SmallVector<IndexedEdge, 10> edges;
|
||||||
for (auto& region : entryOp->getRegions())
|
for (auto& region : entryOp->getRegions())
|
||||||
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
|
for (SpatCompute spatCompute : region.getOps<SpatCompute>())
|
||||||
spatWeightedComputes.push_back(spatWeightedCompute);
|
spatComputes.push_back(spatCompute);
|
||||||
|
|
||||||
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||||
for (Value input : spatWeightedCompute.getInputs()) {
|
for (Value input : spatCompute.getInputs()) {
|
||||||
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) {
|
if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
|
||||||
auto producerIt = llvm::find(spatWeightedComputes, producerCompute);
|
auto producerIt = llvm::find(spatComputes, producerCompute);
|
||||||
assert(producerIt != spatWeightedComputes.end());
|
assert(producerIt != spatComputes.end());
|
||||||
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt);
|
auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
|
||||||
ResultRange outputs = producerCompute.getResults();
|
edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
|
||||||
int64_t totalSize = 0;
|
|
||||||
for (auto output : outputs) {
|
|
||||||
ShapedType resultType = cast<ShapedType>(output.getType());
|
|
||||||
totalSize += getSizeInBytes(resultType);
|
|
||||||
}
|
|
||||||
edges.push_back({indexStartEdge, indexEndEdge, totalSize});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dcpCriticalWindowSize.getValue() == 0)
|
if (dcpCriticalWindowSize.getValue() == 0)
|
||||||
return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext());
|
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
|
||||||
|
|
||||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges);
|
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
|
||||||
std::set<std::vector<size_t>> seenCriticalWindows;
|
std::set<std::vector<size_t>> seenCriticalWindows;
|
||||||
while (virtualGraph.nodes.size() > 1) {
|
while (virtualGraph.nodes.size() > 1) {
|
||||||
TimingInfo timing = computeTiming(virtualGraph);
|
TimingInfo timing = computeTiming(virtualGraph);
|
||||||
@@ -446,7 +440,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges);
|
return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -10,10 +10,10 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
struct DCPAnalysisResult {
|
struct DCPAnalysisResult {
|
||||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
|
||||||
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap;
|
llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
|
||||||
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu;
|
llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
|
||||||
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
|
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -1260,7 +1260,7 @@ DCPAnalysisResult GraphDCP::getResult() {
|
|||||||
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
|
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
|
||||||
ret.dominanceOrderCompute.reserve(dominanceOrder.size());
|
ret.dominanceOrderCompute.reserve(dominanceOrder.size());
|
||||||
for (auto elem : dominanceOrder)
|
for (auto elem : dominanceOrder)
|
||||||
ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute());
|
ret.dominanceOrderCompute.push_back(elem->getSpatCompute());
|
||||||
|
|
||||||
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
|
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
|
||||||
const CpuTaskList* tasks = findCpuTasks(cpu);
|
const CpuTaskList* tasks = findCpuTasks(cpu);
|
||||||
@@ -1268,10 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() {
|
|||||||
continue;
|
continue;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (auto node : *tasks) {
|
for (auto node : *tasks) {
|
||||||
ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu;
|
ret.computeToCpuMap[node->getSpatCompute()] = cpu;
|
||||||
if (i++ == tasks->size() - 1) {
|
if (i++ == tasks->size() - 1) {
|
||||||
ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute());
|
ret.isLastComputeOfCpu.insert(node->getSpatCompute());
|
||||||
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute();
|
ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,11 +115,11 @@ private:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
void runDcp();
|
void runDcp();
|
||||||
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
|
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
|
||||||
llvm::ArrayRef<IndexedEdge> edges)
|
llvm::ArrayRef<IndexedEdge> edges)
|
||||||
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
||||||
for (auto spatWeightedCompute : spatWeightedComputes)
|
for (auto spatCompute : spatComputes)
|
||||||
nodes.emplace_back(spatWeightedCompute);
|
nodes.emplace_back(spatCompute);
|
||||||
for (auto [start, end, weight] : edges)
|
for (auto [start, end, weight] : edges)
|
||||||
makeEdge(start, end, weight);
|
makeEdge(start, end, weight);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
|
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
|
||||||
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
|
onnx_mlir::spatial::SpatCompute spatCompute;
|
||||||
Time aest;
|
Time aest;
|
||||||
Time alst;
|
Time alst;
|
||||||
std::optional<CPU> scheduledCpu;
|
std::optional<CPU> scheduledCpu;
|
||||||
@@ -38,22 +38,22 @@ public:
|
|||||||
std::vector<Edge> parents;
|
std::vector<Edge> parents;
|
||||||
std::vector<Edge> children;
|
std::vector<Edge> children;
|
||||||
TaskDCP() = default;
|
TaskDCP() = default;
|
||||||
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
|
TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
|
||||||
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||||
spatWeightedCompute(spatWeightedCompute),
|
spatCompute(spatCompute),
|
||||||
aest(0),
|
aest(0),
|
||||||
alst(0),
|
alst(0),
|
||||||
scheduledCpu(),
|
scheduledCpu(),
|
||||||
weight(getSpatComputeWeight(spatWeightedCompute)),
|
weight(getSpatComputeWeight(spatCompute)),
|
||||||
baseWeight(weight),
|
baseWeight(weight),
|
||||||
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)),
|
crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
|
||||||
syntheticId(-1),
|
syntheticId(-1),
|
||||||
parents(),
|
parents(),
|
||||||
children() {}
|
children() {}
|
||||||
|
|
||||||
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
|
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
|
||||||
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||||
spatWeightedCompute(),
|
spatCompute(),
|
||||||
aest(0),
|
aest(0),
|
||||||
alst(0),
|
alst(0),
|
||||||
scheduledCpu(),
|
scheduledCpu(),
|
||||||
@@ -90,14 +90,14 @@ public:
|
|||||||
void setAlst(Time value) { alst = value; }
|
void setAlst(Time value) { alst = value; }
|
||||||
bool hasDescendant(TaskDCP* child);
|
bool hasDescendant(TaskDCP* child);
|
||||||
int64_t Id() const {
|
int64_t Id() const {
|
||||||
if (spatWeightedCompute)
|
if (spatCompute)
|
||||||
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
|
return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
|
||||||
return syntheticId;
|
return syntheticId;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isCriticalPath() const { return alst == aest; }
|
bool isCriticalPath() const { return alst == aest; }
|
||||||
bool isScheduled() const { return scheduledCpu.has_value(); }
|
bool isScheduled() const { return scheduledCpu.has_value(); }
|
||||||
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
|
onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }
|
||||||
|
|
||||||
void setFlag(long long val) { flag = val; }
|
void setFlag(long long val) { flag = val; }
|
||||||
long long getFlag() const { return flag; }
|
long long getFlag() const { return flag; }
|
||||||
|
|||||||
@@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) {
|
|||||||
|
|
||||||
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
||||||
|
|
||||||
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
|
||||||
constexpr Weight kOperationWeight = 100;
|
constexpr Weight kOperationWeight = 100;
|
||||||
Weight numOperations = 0;
|
Weight numOperations = 0;
|
||||||
for (auto& block : spatWeightedCompute.getBody())
|
for (auto& block : spatCompute.getBody())
|
||||||
for ([[maybe_unused]] auto& op : block)
|
for ([[maybe_unused]] auto& op : block)
|
||||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||||
return checkedMultiply(numOperations, kOperationWeight);
|
return checkedMultiply(numOperations, kOperationWeight);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
|
||||||
CrossbarUsage crossbarUsage = 0;
|
CrossbarUsage crossbarUsage = 0;
|
||||||
for (auto& region : spatWeightedCompute.getBody())
|
for (auto& region : spatCompute.getBody())
|
||||||
for (auto& inst : region)
|
for (auto& inst : region)
|
||||||
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
||||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||||
|
|||||||
@@ -24,30 +24,29 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
using SpatWeightedCompute = spatial::SpatWeightedCompute;
|
using SpatCompute = spatial::SpatCompute;
|
||||||
|
|
||||||
struct ComputeValueResults {
|
struct ComputeValueResults {
|
||||||
// Value yielded by the yieldOp
|
SmallVector<Value> innerValues;
|
||||||
Value innerValue;
|
|
||||||
|
Value get(size_t resultIndex) const {
|
||||||
|
assert(resultIndex < innerValues.size() && "compute result index out of range");
|
||||||
|
return innerValues[resultIndex];
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class LazyInsertComputeResult {
|
class LazyInsertComputeResult {
|
||||||
using InsertPoint = mlir::IRRewriter::InsertPoint;
|
using InsertPoint = mlir::IRRewriter::InsertPoint;
|
||||||
ComputeValueResults computeResults;
|
ComputeValueResults computeResults;
|
||||||
Value channelValue;
|
|
||||||
bool onlyChannel;
|
bool onlyChannel;
|
||||||
std::function<void(InsertPoint insertPoint)> channelSendInserter;
|
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter;
|
||||||
InsertPoint sendInsertPoint;
|
|
||||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LazyInsertComputeResult(ComputeValueResults computeValueResults,
|
LazyInsertComputeResult(ComputeValueResults computeValueResults,
|
||||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
|
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
|
||||||
bool isOnlyChannel)
|
bool isOnlyChannel)
|
||||||
: computeResults(computeValueResults),
|
: computeResults(computeValueResults),
|
||||||
onlyChannel(isOnlyChannel),
|
onlyChannel(isOnlyChannel),
|
||||||
channelSendInserter(nullptr),
|
|
||||||
sendInsertPoint({}),
|
|
||||||
channelNewInserter(channelNewInserter) {}
|
channelNewInserter(channelNewInserter) {}
|
||||||
|
|
||||||
struct ChannelOrLocalOp {
|
struct ChannelOrLocalOp {
|
||||||
@@ -57,12 +56,12 @@ public:
|
|||||||
|
|
||||||
bool onlyChanneled() const { return onlyChannel; }
|
bool onlyChanneled() const { return onlyChannel; }
|
||||||
|
|
||||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) {
|
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) {
|
||||||
|
Value innerValue = computeResults.get(resultIndex);
|
||||||
|
|
||||||
auto [newChannelValue, senderInserter] = channelNewInserter();
|
auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex);
|
||||||
channelValue = newChannelValue;
|
InsertPoint sendInsertPoint;
|
||||||
channelSendInserter = senderInserter;
|
auto* block = innerValue.getParentBlock();
|
||||||
auto* block = computeResults.innerValue.getParentBlock();
|
|
||||||
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
|
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
|
||||||
sendInsertPoint = InsertPoint(block, --block->end());
|
sendInsertPoint = InsertPoint(block, --block->end());
|
||||||
else
|
else
|
||||||
@@ -70,28 +69,30 @@ public:
|
|||||||
if (currentCompute) {
|
if (currentCompute) {
|
||||||
for (auto& block : currentCompute.getBody())
|
for (auto& block : currentCompute.getBody())
|
||||||
if (&block == sendInsertPoint.getBlock())
|
if (&block == sendInsertPoint.getBlock())
|
||||||
return {computeResults.innerValue, false};
|
return {innerValue, false};
|
||||||
}
|
}
|
||||||
channelSendInserter(sendInsertPoint);
|
channelSendInserter(sendInsertPoint);
|
||||||
return {channelValue, true};
|
return {channelValue, true};
|
||||||
}
|
}
|
||||||
|
|
||||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) {
|
||||||
|
return getAsChannelValueAndInsertSender({}, resultIndex);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
DenseMap<SpatCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||||
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
|
DenseMap<SpatCompute, SpatCompute> oldToNewComputeMap;
|
||||||
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;
|
DenseMap<int64_t, SpatCompute> cpuToNewComputeMap;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
return "Merge Spatial-Compute-Nodes in order to reduce the total "
|
||||||
"execution time";
|
"execution time";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,22 +106,22 @@ public:
|
|||||||
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
||||||
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
||||||
if (!cpuToNewComputeMap.contains(cpu)) {
|
if (!cpuToNewComputeMap.contains(cpu)) {
|
||||||
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||||
auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
|
auto [newCompute, computeValueResult] = createNewComputeNode(
|
||||||
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||||
cpuToNewComputeMap[cpu] = newWeightedCompute;
|
cpuToNewComputeMap[cpu] = newCompute;
|
||||||
newComputeNodeResults.insert(
|
newComputeNodeResults.insert(
|
||||||
std::make_pair(currentComputeNode,
|
std::make_pair(currentComputeNode,
|
||||||
createLazyComputeResult(
|
createLazyComputeResult(
|
||||||
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
|
auto [newCompute, computeValueResult] = mergeIntoComputeNode(
|
||||||
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
||||||
newComputeNodeResults.insert(
|
newComputeNodeResults.insert(
|
||||||
std::make_pair(currentComputeNode,
|
std::make_pair(currentComputeNode,
|
||||||
createLazyComputeResult(
|
createLazyComputeResult(
|
||||||
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,8 +135,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
|
std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
|
||||||
SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
|
SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
auto loc = func.getLoc();
|
auto loc = func.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
@@ -148,50 +149,53 @@ private:
|
|||||||
llvm::SmallVector<Type> newBBOperandType;
|
llvm::SmallVector<Type> newBBOperandType;
|
||||||
llvm::SmallVector<Location> newBBLocations;
|
llvm::SmallVector<Location> newBBLocations;
|
||||||
|
|
||||||
for (auto arg : oldWeightedCompute.getWeights())
|
for (auto arg : oldCompute.getWeights())
|
||||||
newComputeOperand.push_back(arg);
|
newComputeOperand.push_back(arg);
|
||||||
|
|
||||||
for (auto arg : oldWeightedCompute.getInputs())
|
for (auto arg : oldCompute.getInputs())
|
||||||
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
if (!llvm::isa_and_present<SpatCompute>(arg.getDefiningOp())) {
|
||||||
newComputeOperand.push_back(arg);
|
newComputeOperand.push_back(arg);
|
||||||
newBBOperandType.push_back(arg.getType());
|
newBBOperandType.push_back(arg.getType());
|
||||||
newBBLocations.push_back(loc);
|
newBBLocations.push_back(loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
|
auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
|
||||||
|
|
||||||
rewriter.createBlock(
|
rewriter.createBlock(
|
||||||
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
|
&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||||
newWeightedCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
|
{(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||||
|
|
||||||
auto& newBB = newWeightedCompute.getBody().front();
|
auto& newBB = newCompute.getBody().front();
|
||||||
auto& oldBB = oldWeightedCompute.getBody().front();
|
auto& oldBB = oldCompute.getBody().front();
|
||||||
rewriter.setInsertionPointToEnd(&newBB);
|
rewriter.setInsertionPointToEnd(&newBB);
|
||||||
|
|
||||||
int indexNew = 0;
|
int indexNew = 0;
|
||||||
size_t indexOld = oldWeightedCompute.getWeights().size();
|
size_t indexOld = oldCompute.getWeights().size();
|
||||||
size_t indexOldStart = oldWeightedCompute.getWeights().size();
|
size_t indexOldStart = oldCompute.getWeights().size();
|
||||||
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
|
for (; indexOld < oldCompute.getNumOperands(); ++indexOld) {
|
||||||
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
|
if (!llvm::isa_and_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp())) {
|
||||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
|
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto argWeightCompute =
|
auto argWeightCompute =
|
||||||
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
|
llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
|
||||||
|
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
|
||||||
|
|
||||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
|
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
|
||||||
assert(isChannel == true);
|
assert(isChannel == true);
|
||||||
spatial::SpatChannelReceiveOp receiveOp =
|
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
|
||||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
|
rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
|
||||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
|
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& op : oldWeightedCompute.getOps()) {
|
for (auto& op : oldCompute.getOps()) {
|
||||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
|
computeValueResults.innerValues.reserve(yield.getNumOperands());
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
|
||||||
if (lastCompute)
|
if (lastCompute)
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
}
|
}
|
||||||
@@ -199,16 +203,18 @@ private:
|
|||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
|
for (auto& use : llvm::make_early_inc_range(oldCompute->getUses()))
|
||||||
if (isa<func::ReturnOp>(use.getOwner()))
|
if (isa<func::ReturnOp>(use.getOwner())) {
|
||||||
use.assign(newWeightedCompute.getResult(0));
|
auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
|
||||||
|
use.assign(newCompute.getResult(resultIndex));
|
||||||
|
}
|
||||||
|
|
||||||
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
|
oldToNewComputeMap.insert({oldCompute, newCompute});
|
||||||
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
|
return {cast<SpatCompute>(newCompute), computeValueResults};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<SpatWeightedCompute, ComputeValueResults>
|
std::pair<SpatCompute, ComputeValueResults>
|
||||||
mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
|
mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
auto loc = func.getLoc();
|
auto loc = func.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
@@ -239,14 +245,15 @@ private:
|
|||||||
// Insert receiveOp
|
// Insert receiveOp
|
||||||
rewriter.setInsertionPointToEnd(&toBB);
|
rewriter.setInsertionPointToEnd(&toBB);
|
||||||
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
|
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
|
||||||
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(arg.getDefiningOp())) {
|
||||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||||
|
auto argResultIndex = cast<OpResult>(arg).getResultNumber();
|
||||||
|
|
||||||
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
|
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
|
||||||
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
|
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex);
|
||||||
if (channelOrLocal.isChannel) {
|
if (channelOrLocal.isChannel) {
|
||||||
spatial::SpatChannelReceiveOp receiveOp =
|
spatial::SpatChannelReceiveOp receiveOp =
|
||||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
|
spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data);
|
||||||
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
|
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@@ -286,7 +293,9 @@ private:
|
|||||||
};
|
};
|
||||||
for (auto& op : fromCompute.getOps()) {
|
for (auto& op : fromCompute.getOps()) {
|
||||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
|
computeValueResults.innerValues.reserve(yield.getNumOperands());
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
|
||||||
if (lastCompute)
|
if (lastCompute)
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
}
|
}
|
||||||
@@ -299,33 +308,36 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto users : fromCompute->getUsers())
|
for (auto& use : llvm::make_early_inc_range(fromCompute->getUses()))
|
||||||
if (auto funcRet = dyn_cast<func::ReturnOp>(users))
|
if (isa<func::ReturnOp>(use.getOwner())) {
|
||||||
funcRet.setOperand(0, toCompute.getResult(0));
|
auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
|
||||||
|
use.assign(toCompute.getResult(resultIndex));
|
||||||
|
}
|
||||||
|
|
||||||
oldToNewComputeMap.insert({fromCompute, toCompute});
|
oldToNewComputeMap.insert({fromCompute, toCompute});
|
||||||
return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
|
return {cast<SpatCompute>(toCompute), computeValueResults};
|
||||||
}
|
}
|
||||||
|
|
||||||
LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
|
LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
|
||||||
ComputeValueResults computeValueResults,
|
ComputeValueResults computeValueResults,
|
||||||
bool lastCompute) {
|
bool lastCompute) {
|
||||||
func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
|
func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
|
||||||
auto* context = &getContext();
|
auto* context = &getContext();
|
||||||
auto loc = funcOp.getLoc();
|
auto loc = funcOp.getLoc();
|
||||||
IRRewriter rewriter(context);
|
IRRewriter rewriter(context);
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(&funcOp.front());
|
rewriter.setInsertionPointToStart(&funcOp.front());
|
||||||
auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
|
auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
|
||||||
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() {
|
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) {
|
||||||
IRRewriter rewriter(context);
|
IRRewriter rewriter(context);
|
||||||
rewriter.restoreInsertionPoint(savedChannelInsertPoint);
|
rewriter.restoreInsertionPoint(savedChannelInsertPoint);
|
||||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
|
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
|
||||||
auto channelVal = channelOp.getResult();
|
auto channelVal = channelOp.getResult();
|
||||||
auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) {
|
auto insertVal =
|
||||||
|
[&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
|
||||||
IRRewriter rewriter(context);
|
IRRewriter rewriter(context);
|
||||||
rewriter.restoreInsertionPoint(sendInsertPoint);
|
rewriter.restoreInsertionPoint(sendInsertPoint);
|
||||||
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
|
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
|
||||||
return spatSend;
|
return spatSend;
|
||||||
};
|
};
|
||||||
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
|
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ struct CountInstructionPass : public PassWrapper<CountInstructionPass, Operation
|
|||||||
unsigned totalInstructionCount = 0;
|
unsigned totalInstructionCount = 0;
|
||||||
|
|
||||||
unsigned computeId = 0;
|
unsigned computeId = 0;
|
||||||
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) {
|
for (auto computeOp : func.getOps<spatial::SpatCompute>()) {
|
||||||
unsigned instructionCount = 0;
|
unsigned instructionCount = 0;
|
||||||
instructionCount += computeOp.getBody().front().getOperations().size();
|
instructionCount += computeOp.getBody().front().getOperations().size();
|
||||||
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
|
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ STAGE_COUNT = len(STAGE_TITLES)
|
|||||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_output_name(name):
|
||||||
|
return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ValidationResult:
|
class ValidationResult:
|
||||||
passed: bool
|
passed: bool
|
||||||
@@ -205,7 +209,7 @@ def build_dump_ranges(config_path, outputs_descriptor):
|
|||||||
|
|
||||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
||||||
run_command(
|
run_command(
|
||||||
["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
|
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
|
||||||
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
||||||
cwd=simulator_dir,
|
cwd=simulator_dir,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
@@ -229,7 +233,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
|
|||||||
all_passed = True
|
all_passed = True
|
||||||
rows = []
|
rows = []
|
||||||
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
|
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
|
||||||
csv_name = f"output{oi}_{name}.csv"
|
csv_name = f"output{oi}_{sanitize_output_name(name)}.csv"
|
||||||
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
|
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
|
||||||
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
|
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
|
||||||
passed = max_diff <= threshold
|
passed = max_diff <= threshold
|
||||||
|
|||||||
Reference in New Issue
Block a user