better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
spat.map
This commit is contained in:
@@ -111,7 +111,7 @@ static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t&
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
@@ -178,6 +178,43 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
||||
rewriter.replaceOp(receiveManyOp, ValueRange(replacements));
|
||||
}
|
||||
|
||||
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
|
||||
int32_t laneCount,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
auto targetCoreIds = sendManyBatchOp.getTargetCoreIds();
|
||||
for (auto [valueIndex, input] : llvm::enumerate(sendManyBatchOp.getInputs())) {
|
||||
size_t metadataOffset = valueIndex * static_cast<size_t>(laneCount);
|
||||
auto targetSlice = targetCoreIds.slice(metadataOffset, laneCount);
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
sendManyBatchOp.getLoc(),
|
||||
mapper.lookup(input),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(input)),
|
||||
rewriter.getDenseI32ArrayAttr(targetSlice));
|
||||
}
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
||||
int32_t laneCount,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds();
|
||||
for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
|
||||
size_t metadataOffset = valueIndex * static_cast<size_t>(laneCount);
|
||||
auto sourceSlice = sourceCoreIds.slice(metadataOffset, laneCount);
|
||||
auto outputType = cast<ShapedType>(output.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
receiveManyBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, output),
|
||||
rewriter.getDenseI32ArrayAttr(sourceSlice))
|
||||
.getOutput();
|
||||
mapper.map(output, received);
|
||||
}
|
||||
}
|
||||
|
||||
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
||||
Value input = extractRowsOp.getInput();
|
||||
RankedTensorType inputType;
|
||||
@@ -226,6 +263,56 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
||||
rewriter.replaceOp(concatOp, concatenated);
|
||||
}
|
||||
|
||||
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
|
||||
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
|
||||
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
wvmmOps.push_back(wvmmOp);
|
||||
});
|
||||
|
||||
for (auto wvmmOp : wvmmOps) {
|
||||
rewriter.setInsertionPoint(wvmmOp);
|
||||
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
|
||||
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
|
||||
wvmmOp.getOutput().getType(),
|
||||
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
|
||||
wvmmOp.getInput(),
|
||||
outputBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatMapOp> mapOps;
|
||||
funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
|
||||
|
||||
for (auto mapOp : mapOps) {
|
||||
Block& body = mapOp.getBody().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
|
||||
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(mapOp.getInputs().size());
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
for (Value input : mapOp.getInputs()) {
|
||||
IRMapping mapping;
|
||||
mapping.map(body.getArgument(0), input);
|
||||
|
||||
Value replacement = input;
|
||||
for (Operation& op : body.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapping.map(originalResult, clonedResult);
|
||||
rewriter.setInsertionPointAfter(cloned);
|
||||
}
|
||||
|
||||
replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||
replacements.push_back(replacement);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(mapOp, replacements);
|
||||
}
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
@@ -551,6 +638,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
expandMapOps(funcOp, rewriter);
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect,
|
||||
@@ -640,6 +728,32 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto extractRowsOp : extractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
{
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lowerRemainingSpatialMathOps(funcOp, rewriter);
|
||||
|
||||
RewritePatternSet channelPatterns(ctx);
|
||||
populateWithGenerated(channelPatterns);
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||
@@ -939,7 +1053,7 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
||||
ValueRange(batchInputs));
|
||||
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
@@ -1000,6 +1114,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendManyBatchOp = dyn_cast<spatial::SpatChannelSendManyBatchOp>(op)) {
|
||||
lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
@@ -1014,6 +1133,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveManyBatchOp = dyn_cast<spatial::SpatChannelReceiveManyBatchOp>(op)) {
|
||||
lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
|
||||
Reference in New Issue
Block a user