diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 59c5ee1..6880012 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -22,6 +22,7 @@ namespace onnx_mlir { -inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id"; +inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId"; +inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds"; } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index c6606b9..7a40d44 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -517,8 +517,8 @@ static SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { } static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { - auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdAttrName); - assert(coreIdsAttr && "pim.core_batch requires core_id array attribute"); + auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 5ec197b..e0f44c8 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -111,7 +111,7 @@ static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& } static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { - if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); SmallVector coreIds; @@ -178,6 +178,43 @@ static void 
lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan rewriter.replaceOp(receiveManyOp, ValueRange(replacements)); } +static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp, + int32_t laneCount, + IRMapping& mapper, + IRRewriter& rewriter) { + auto targetCoreIds = sendManyBatchOp.getTargetCoreIds(); + for (auto [valueIndex, input] : llvm::enumerate(sendManyBatchOp.getInputs())) { + size_t metadataOffset = valueIndex * static_cast(laneCount); + auto targetSlice = targetCoreIds.slice(metadataOffset, laneCount); + pim::PimSendBatchOp::create(rewriter, + sendManyBatchOp.getLoc(), + mapper.lookup(input), + getTensorSizeInBytesAttr(rewriter, mapper.lookup(input)), + rewriter.getDenseI32ArrayAttr(targetSlice)); + } +} + +static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp, + int32_t laneCount, + IRMapping& mapper, + IRRewriter& rewriter) { + auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds(); + for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) { + size_t metadataOffset = valueIndex * static_cast(laneCount); + auto sourceSlice = sourceCoreIds.slice(metadataOffset, laneCount); + auto outputType = cast(output.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType); + auto received = pim::PimReceiveBatchOp::create(rewriter, + receiveManyBatchOp.getLoc(), + outputBuffer.getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, output), + rewriter.getDenseI32ArrayAttr(sourceSlice)) + .getOutput(); + mapper.map(output, received); + } +} + static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { Value input = extractRowsOp.getInput(); RankedTensorType inputType; @@ -226,6 +263,56 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { rewriter.replaceOp(concatOp, concatenated); } +static void 
lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector wvmmOps; + funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) { + if (wvmmOp->getParentOfType() || wvmmOp->getParentOfType()) + wvmmOps.push_back(wvmmOp); + }); + + for (auto wvmmOp : wvmmOps) { + rewriter.setInsertionPoint(wvmmOp); + auto outputType = cast(wvmmOp.getOutput().getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult(); + rewriter.replaceOpWithNewOp(wvmmOp, + wvmmOp.getOutput().getType(), + rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()), + wvmmOp.getInput(), + outputBuffer); + } +} + +static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector mapOps; + funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); }); + + for (auto mapOp : mapOps) { + Block& body = mapOp.getBody().front(); + auto yieldOp = cast(body.getTerminator()); + + SmallVector replacements; + replacements.reserve(mapOp.getInputs().size()); + rewriter.setInsertionPoint(mapOp); + for (Value input : mapOp.getInputs()) { + IRMapping mapping; + mapping.map(body.getArgument(0), input); + + Value replacement = input; + for (Operation& op : body.without_terminator()) { + Operation* cloned = rewriter.clone(op, mapping); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapping.map(originalResult, clonedResult); + rewriter.setInsertionPointAfter(cloned); + } + + replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); + replacements.push_back(replacement); + } + + rewriter.replaceOp(mapOp, replacements); + } +} + static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, SmallVectorImpl& helperChain, bool requireReturnUse = true) { @@ -551,6 +638,7 @@ void SpatialToPimPass::runOnOperation() { func::FuncOp funcOp = *entryFunc; IRRewriter rewriter(&getContext()); + expandMapOps(funcOp, rewriter); ConversionTarget target(*ctx); 
target.addLegalDialect coreOps; + funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); + for (auto coreOp : coreOps) { + if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) { + signalPassFailure(); + return; + } + } + + SmallVector coreBatchOps; + funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); + for (auto coreBatchOp : coreBatchOps) { + if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) { + signalPassFailure(); + return; + } + } + } + + lowerRemainingSpatialMathOps(funcOp, rewriter); + RewritePatternSet channelPatterns(ctx); populateWithGenerated(channelPatterns); if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) { @@ -939,7 +1053,7 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc ValueRange(batchInputs)); coreBatchOp.getProperties().setOperandSegmentSizes( {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); - coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); SmallVector blockArgTypes; SmallVector blockArgLocs; @@ -1000,6 +1114,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc continue; } + if (auto sendManyBatchOp = dyn_cast(op)) { + lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter); + continue; + } + if (auto receiveBatchOp = dyn_cast(op)) { auto outputType = cast(receiveBatchOp.getOutput().getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); @@ -1014,6 +1133,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc continue; } + if (auto receiveManyBatchOp = dyn_cast(op)) { + lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter); + continue; + } + 
if (auto toTensorOp = dyn_cast(op)) { if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) { Operation* cloned = rewriter.clone(op, mapper); diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 7a2b7d8..3ecc353 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -39,7 +39,7 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> { }]; } -def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> { +def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> { let summary = "Execute equivalent batched core bodies"; let regions = (region SizedRegion<1>:$body); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index c9a6769..481d0af 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -257,8 +257,8 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(weights.size()), static_cast(inputs.size())}); - if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName)) - newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr); + if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdsAttrName)) + newOp->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr); rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin()); for (Block& block : newOp.getBody()) diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 50f5ce0..6f6a7ff 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -8,6 +8,7 @@ add_pim_library(SpatialOps SpatialOpsVerify.cpp SpatialOpsCanonicalization.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp + 
Transforms/MergeComputeNodes/RegularOpCompaction.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 238472f..9a074a6 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -102,6 +102,23 @@ def SpatConcatOp : SpatOp<"concat", []> { let hasCustomAssemblyFormat = 1; } +def SpatMapOp : SpatOp<"map", [SingleBlock]> { + let summary = "Apply the same lane-local region to many independent tensors"; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// @@ -184,6 +201,20 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { let hasCustomAssemblyFormat = 1; } +def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> { + let summary = "Send multiple per-lane tensors through logical channels in a batch body"; + + let arguments = (ins + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds, + Variadic:$inputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { let summary = "Receive a per-lane tensor through logical channels in a batch body"; @@ -201,11 +232,28 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { let hasCustomAssemblyFormat = 1; } +def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> { + let summary = "Receive multiple per-lane tensors through logical channels in a batch body"; + + let arguments 
= (ins + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds + ); + + let results = (outs + Variadic:$outputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Math //===----------------------------------------------------------------------===// -def SpatWeightedVMMOp : SpatOp<"Wvmm", []> { +def SpatWeightedVMMOp : SpatOp<"wvmm", []> { let summary = "Vector-matrix multiplication within a weighted compute operation"; let arguments = (ins diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 022b623..d9ac3c5 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -42,6 +42,10 @@ static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) printer << (delimiter == ListDelimiter::Square ? "]" : ")"); } +static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) { + return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy)); +} + template static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, ListDelimiter delimiter, @@ -75,51 +79,65 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, } template -static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { - if (parser.parseLSquare()) - return failure(); - if (succeeded(parser.parseOptionalRSquare())) +static ParseResult +parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); while (true) { - int64_t first = 0; - if (parser.parseInteger(first)) - return failure(); + if (succeeded(parser.parseOptionalLParen())) { + SmallVector subgroup; + if 
(parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup)) + return failure(); - if (succeeded(parser.parseOptionalKeyword("to"))) { - int64_t last = 0; - if (parser.parseInteger(last) || last < first) - return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); - - int64_t step = 1; - if (succeeded(parser.parseOptionalKeyword("by"))) { - if (parser.parseInteger(step) || step <= 0) - return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); - } int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } - if ((last - first) % step != 0) - return parser.emitError(parser.getCurrentLocation(), - "range end must be reachable from start using the given step"); - - for (int64_t value = first; value <= last; value += step) - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(value)); + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(values, subgroup); } else { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + int64_t first = 0; + if (parser.parseInteger(first)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("to"))) { + int64_t last = 0; + if (parser.parseInteger(last) || last < first) + return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); + + int64_t step = 1; + if (succeeded(parser.parseOptionalKeyword("by"))) { + if (parser.parseInteger(step) || step <= 0) + return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); + } + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if 
(parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + if ((last - first) % step != 0) + return parser.emitError(parser.getCurrentLocation(), + "range end must be reachable from start using the given step"); + + for (int64_t value = first; value <= last; value += step) + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(value)); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(first)); } - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(first)); } - if (succeeded(parser.parseOptionalRSquare())) + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) break; if (parser.parseComma()) return failure(); @@ -128,6 +146,14 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm return success(); } +template +static ParseResult +parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + return parseCompressedIntegerEntries(parser, delimiter, values); +} + template static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { for (size_t index = 0; index < entries.size();) { @@ -146,35 +172,51 @@ static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin } template -static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { - printer << "["; - for (size_t index = 0; index < values.size();) { - if (index != 0) - printer << ", "; - - auto findEqualRunEnd = [&](size_t start) { - size_t 
end = start + 1; - while (end < values.size() && values[end] == values[start]) - ++end; - return end; +static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef values, ListDelimiter delimiter) { + struct FlatCompression { + enum class Kind { + Single, + EqualRun, + Progression }; - size_t firstRunEnd = findEqualRunEnd(index); - size_t repeatCount = firstRunEnd - index; + Kind kind = Kind::Single; + size_t covered = 1; + size_t repeatCount = 1; + size_t progressionValueCount = 1; + int64_t step = 1; + IntT firstValue {}; + IntT lastValue {}; + }; + + auto computeFlatCompression = [&](size_t start) { + FlatCompression compression; + compression.firstValue = values[start]; + compression.lastValue = values[start]; + + auto findEqualRunEnd = [&](size_t runStart) { + size_t runEnd = runStart + 1; + while (runEnd < values.size() && values[runEnd] == values[runStart]) + ++runEnd; + return runEnd; + }; + + size_t firstRunEnd = findEqualRunEnd(start); + compression.repeatCount = firstRunEnd - start; size_t progressionEnd = firstRunEnd; int64_t step = 0; - IntT lastValue = values[index]; + IntT lastValue = values[start]; if (firstRunEnd < values.size()) { size_t secondRunEnd = findEqualRunEnd(firstRunEnd); - step = static_cast(values[firstRunEnd]) - static_cast(values[index]); - if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) { + step = static_cast(values[firstRunEnd]) - static_cast(values[start]); + if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) { progressionEnd = secondRunEnd; lastValue = values[firstRunEnd]; size_t currentRunStart = secondRunEnd; while (currentRunStart < values.size()) { size_t currentRunEnd = findEqualRunEnd(currentRunStart); - if (currentRunEnd - currentRunStart != repeatCount) + if (currentRunEnd - currentRunStart != compression.repeatCount) break; if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) break; @@ -188,27 +230,99 @@ static void 
printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef val } } - size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount; - if (progressionEnd > firstRunEnd && progressionValueCount >= 3) { - printer << values[index] << " to " << lastValue; - if (step != 1) - printer << " by " << step; - if (repeatCount > 1) - printer << " x" << repeatCount; - index = progressionEnd; - continue; + compression.covered = 1; + if (progressionEnd > firstRunEnd) { + size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount; + if (progressionValueCount >= 3) { + compression.kind = FlatCompression::Kind::Progression; + compression.covered = progressionEnd - start; + compression.progressionValueCount = progressionValueCount; + compression.step = step; + compression.lastValue = lastValue; + return compression; + } } - if (repeatCount > 1) { - printer << values[index] << " x" << repeatCount; - index = firstRunEnd; - continue; + if (compression.repeatCount > 1) { + compression.kind = FlatCompression::Kind::EqualRun; + compression.covered = compression.repeatCount; + return compression; } - printer << values[index]; - index = firstRunEnd; + return compression; + }; + + auto findRepeatedSublist = [&](size_t start) { + size_t bestLength = 0; + size_t bestRepeatCount = 1; + size_t remaining = values.size() - start; + + for (size_t length = 2; length * 2 <= remaining; ++length) { + size_t repeatCount = 1; + ArrayRef candidate = values.slice(start, length); + while (start + (repeatCount + 1) * length <= values.size() + && llvm::equal(candidate, values.slice(start + repeatCount * length, length))) { + ++repeatCount; + } + + if (repeatCount <= 1) + continue; + + size_t covered = length * repeatCount; + size_t bestCovered = bestLength * bestRepeatCount; + if (covered > bestCovered || (covered == bestCovered && length < bestLength)) { + bestLength = length; + bestRepeatCount = repeatCount; + } + } + + return std::pair(bestLength, 
bestRepeatCount); + }; + + printOpenDelimiter(printer, delimiter); + for (size_t index = 0; index < values.size();) { + if (index != 0) + printer << ", "; + + FlatCompression flat = computeFlatCompression(index); + auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); + size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; + if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { + printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren); + printer << " x" << sublistRepeatCount; + index += repeatedSublistCoverage; + continue; + } + switch (flat.kind) { + case FlatCompression::Kind::Progression: + printer << flat.firstValue << " to " << flat.lastValue; + if (flat.step != 1) + printer << " by " << flat.step; + if (flat.repeatCount > 1) + printer << " x" << flat.repeatCount; + index += flat.covered; + break; + case FlatCompression::Kind::EqualRun: + printer << flat.firstValue << " x" << flat.repeatCount; + index += flat.covered; + break; + case FlatCompression::Kind::Single: + printer << flat.firstValue; + index += flat.covered; + break; + } } - printer << "]"; + printCloseDelimiter(printer, delimiter); +} + +template +static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { + return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values); +} + +template +static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { + printCompressedIntegerSequence(printer, values, ListDelimiter::Square); } static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { @@ -267,6 +381,165 @@ static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List printCloseDelimiter(printer, delimiter); } +static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, + SmallVectorImpl& operands); +static ParseResult 
parseCompressedOperandSequence(OpAsmParser& parser, + SmallVectorImpl& operands); +static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty); + +static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) { + if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0) + return false; + + SmallVector valueVec(values.begin(), values.end()); + ArrayRef tuple(valueVec.data(), tupleSize); + for (size_t index = tupleSize; index < values.size(); index += tupleSize) + if (!llvm::equal(tuple, ArrayRef(valueVec).slice(index, tupleSize))) + return false; + return true; +} + +static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { + if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0) + return false; + + SmallVector typeVec(types.begin(), types.end()); + ArrayRef tuple(typeVec.data(), tupleSize); + for (size_t index = tupleSize; index < types.size(); index += tupleSize) + if (!llvm::equal(tuple, ArrayRef(typeVec).slice(index, tupleSize))) + return false; + return true; +} + +static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) { + printer << "["; + printOpenDelimiter(printer, ListDelimiter::Paren); + for (size_t index = 0; index < tupleSize; ++index) { + if (index != 0) + printer << ", "; + printer.printOperand(values[index]); + } + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " x" << (values.size() / tupleSize) << "]"; +} + +static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) { + printer << "["; + printOpenDelimiter(printer, ListDelimiter::Paren); + for (size_t index = 0; index < tupleSize; ++index) { + if (index != 0) + printer << ", "; + printer.printType(types[index]); + } + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " x" << (types.size() / tupleSize) << "]"; +} + +static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser, + SmallVectorImpl& 
operands) { + if (parser.parseLSquare()) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + if (succeeded(parser.parseOptionalLParen())) { + SmallVector tupleOperands; + if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(operands, tupleOperands); + + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseLParen()) + return failure(); + tupleOperands.clear(); + if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) + return failure(); + + repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(operands, tupleOperands); + } + return parser.parseRSquare(); + } + + while (true) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl& types) { + if (parser.parseLSquare()) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + if (succeeded(parser.parseOptionalLParen())) { + SmallVector tupleTypes; + if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if 
(parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(types, tupleTypes); + + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseLParen()) + return failure(); + tupleTypes.clear(); + if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) + return failure(); + + repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(types, tupleTypes); + } + return parser.parseRSquare(); + } + + while (true) { + Type type; + if (parser.parseType(type)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + types.push_back(type); + + if (succeeded(parser.parseOptionalRSquare())) + return success(); + if (parser.parseComma()) + return failure(); + } +} + static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, OpAsmParser::UnresolvedOperand firstOperand, SmallVectorImpl& operands) { @@ -440,19 +713,88 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { return parser.getBuilder().getI32IntegerAttr(value); } -static void buildImplicitRegionArgs(OpAsmParser& parser, - ArrayRef inputTypes, - SmallVectorImpl& generatedNames, - SmallVectorImpl& arguments) { - generatedNames.reserve(inputTypes.size()); - arguments.reserve(inputTypes.size()); - for (auto [index, inputType] : 
llvm::enumerate(inputTypes)) { - generatedNames.push_back("arg" + std::to_string(index + 1)); - OpAsmParser::Argument arg; - arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0}; - arg.type = inputType; - arguments.push_back(arg); +static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) { + if (block.getNumArguments() == 0) { + printer << "() = ()"; + return; } + + if (block.getNumArguments() == 1) { + printer.printOperand(block.getArgument(0)); + printer << " = "; + printCompressedValueList(printer, operands, ListDelimiter::Paren); + return; + } + + printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren); + printer << " = "; + printCompressedValueList(printer, operands, ListDelimiter::Paren); +} + +static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser, + OpAsmParser::Argument firstArgument, + SmallVectorImpl& arguments) { + if (succeeded(parser.parseOptionalKeyword("to"))) { + OpAsmParser::Argument lastArgument; + if (parser.parseArgument(lastArgument)) + return failure(); + if (firstArgument.ssaName.name != lastArgument.ssaName.name + || firstArgument.ssaName.number > lastArgument.ssaName.number) { + return parser.emitError(parser.getCurrentLocation(), "invalid argument range"); + } + for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) { + OpAsmParser::Argument argument; + argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number}; + arguments.push_back(argument); + } + return success(); + } + + arguments.push_back(firstArgument); + return success(); +} + +static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser, + SmallVectorImpl& arguments) { + OpAsmParser::Argument firstArgument; + if (parser.parseArgument(firstArgument)) + return failure(); + return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments); +} + +static void 
applyArgumentTypes(ArrayRef inputTypes, SmallVectorImpl& arguments) { + for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes)) + argument.type = inputType; +} + +static ParseResult parseArgumentBindings(OpAsmParser& parser, + SmallVectorImpl& arguments, + SmallVectorImpl& operands) { + if (succeeded(parser.parseOptionalLParen())) { + if (succeeded(parser.parseOptionalRParen())) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + return success(); + } + + OpAsmParser::Argument firstArgument; + if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + if (parser.parseRParen() || parser.parseEqual() + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + return success(); + } + + OpAsmParser::Argument argument; + if (parser.parseArgument(argument) || parser.parseEqual() + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + arguments.push_back(argument); + return success(); } } // namespace @@ -519,8 +861,8 @@ ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result void SpatConcatOp::print(OpAsmPrinter& printer) { printer << " axis " << getAxis(); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + printer << " "; + printCompressedValueSequence(printer, getInputs()); printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()}); printer << " : "; printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); @@ -537,11 +879,7 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { if (parser.parseKeyword("axis") || parser.parseInteger(axis)) return failure(); - if 
(succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + if (parseCompressedOperandSequence(parser, inputs)) { return failure(); } @@ -563,14 +901,54 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { return success(); } +void SpatMapOp::print(OpAsmPrinter& printer) { + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getInputs().front().getType()); + printer << " -> "; + printer.printType(getOutputs().front().getType()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector regionArgs; + SmallVector inputs; + Type inputType; + Type outputType; + + if (parseArgumentBindings(parser, regionArgs, inputs)) + return failure(); + if (inputs.empty()) + return parser.emitError(parser.getCurrentLocation(), "map requires at least one input"); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + + SmallVector inputTypes(inputs.size(), inputType); + SmallVector outputTypes(inputs.size(), outputType); + if (regionArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + applyArgumentTypes(inputTypes, regionArgs); + Region* body = result.addRegion(); + return parser.parseRegion(*body, regionArgs); +} + void SpatCompute::print(OpAsmPrinter& 
printer) { printer << " "; printCompressedValueList(printer, getWeights(), ListDelimiter::Square); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - printer << " core_id " << coreIdAttr.getInt(); + printer << " coreId " << coreIdAttr.getInt(); printer.printOptionalAttrDict((*this)->getAttrs(), {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); @@ -587,7 +965,6 @@ void SpatCompute::print(OpAsmPrinter& printer) { ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { SmallVector regionArgs; - SmallVector generatedArgNames; SmallVector weights; SmallVector inputs; SmallVector weightTypes; @@ -598,15 +975,10 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) return failure(); - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + if (parseArgumentBindings(parser, regionArgs, inputs)) return failure(); - } - bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id")); + bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); if (hasCoreId && parser.parseInteger(coreId)) return failure(); @@ -622,9 +994,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (regionArgs.size() != inputs.size()) + return 
parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) return parser.emitError(parser.getCurrentLocation(), - "core_id cannot be specified both positionally and in attr-dict"); + "coreId cannot be specified both positionally and in attr-dict"); auto& builder = parser.getBuilder(); result.addAttribute( @@ -639,27 +1013,34 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { result.addTypes(outputTypes); Region* body = result.addRegion(); - buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + applyArgumentTypes(inputTypes, regionArgs); return parser.parseRegion(*body, regionArgs); } void SpatComputeBatch::print(OpAsmPrinter& printer) { printer << " lanes " << getLaneCount() << " "; - printCompressedValueList(printer, getWeights(), ListDelimiter::Square); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + size_t weightsPerLane = getLaneCount() > 0 ? 
getWeights().size() / static_cast(getLaneCount()) : 0; + if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) + printValueTupleRun(printer, getWeights(), weightsPerLane); + else + printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); - if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { - printer << " core_ids "; + if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) { + printer << " coreIds "; printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); } printer.printOptionalAttrDict( (*this)->getAttrs(), - {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); printer << " : "; - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) + printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane); + else + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); printer << " "; printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); printer << " -> "; @@ -671,7 +1052,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { int32_t laneCount = 0; SmallVector regionArgs; - SmallVector generatedArgNames; SmallVector weights; SmallVector inputs; SmallVector weightTypes; @@ -682,24 +1062,18 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) return failure(); - if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + if (parseCompressedOrTupleOperandList(parser, 
weights)) return failure(); - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + if (parseArgumentBindings(parser, regionArgs, inputs)) return failure(); - } - bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids")); + bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) return failure(); if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedRepeatedList( - parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedOrTupleTypeList(parser, weightTypes) || parseCompressedRepeatedList( parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) @@ -709,8 +1083,11 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName)) - return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict"); + if (regionArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) + return parser.emitError(parser.getCurrentLocation(), + "coreIds cannot be specified both positionally and in attr-dict"); auto& builder = 
parser.getBuilder(); result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); @@ -718,7 +1095,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) "operandSegmentSizes", builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); if (hasCoreIds) - result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds)); + result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds)); if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) @@ -726,7 +1103,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) result.addTypes(outputTypes); Region* body = result.addRegion(); - buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + applyArgumentTypes(inputTypes, regionArgs); return parser.parseRegion(*body, regionArgs); } @@ -867,6 +1244,55 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r return parser.resolveOperand(input, inputType, result.operands); } +void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getInputs()); + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, TypeRange(getInputs())); +} + +ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector inputs; + SmallVector inputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + if (parseCompressedOperandSequence(parser, inputs)) + return failure(); + + 
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); +} + void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); printer.printOptionalAttrDict( @@ -908,5 +1334,47 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState return success(); } +void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) { + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + 
printCompressedTypeSequence(printer, getResultTypes()); +} + +ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputTypes); + return success(); +} + } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index b4d8657..6dc872b 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -83,13 +83,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, } static FailureOr> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) { - if (auto computeOp = dyn_cast(weightedOp->getParentOp())) + if (auto computeOp = weightedOp->getParentOfType()) return 
cast(computeOp.getWeights()[weightIndex].getType()).getShape(); - if (auto coreOp = dyn_cast(weightedOp->getParentOp())) + if (auto coreOp = weightedOp->getParentOfType()) return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); - if (auto batchOp = dyn_cast(weightedOp->getParentOp())) { + if (auto batchOp = weightedOp->getParentOfType()) { if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) return failure(); return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); @@ -144,6 +144,23 @@ static LogicalResult verifyBatchChannelSizes(Operation* op, return success(); } +static LogicalResult verifyManyBatchChannelSizes(Operation* op, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + size_t valueCount) { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); + + auto laneCount = getParentBatchLaneCount(op); + if (failed(laneCount)) + return op->emitError("must be nested inside spat.compute_batch"); + if (channelIds.size() != valueCount * static_cast(*laneCount)) + return op->emitError("channel metadata length must match the number of values times parent laneCount"); + + return success(); +} + static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) @@ -306,6 +323,39 @@ LogicalResult SpatConcatOp::verify() { return success(); } +LogicalResult SpatMapOp::verify() { + if (getInputs().empty()) + return emitError("requires at least one input"); + if (getOutputs().size() != getInputs().size()) + return emitError("number of outputs must match number of inputs"); + + Type inputType = getInputs().front().getType(); + for (Value input : getInputs().drop_front()) + if (input.getType() != inputType) + return emitError("all inputs must have the same 
type"); + + Type outputType = getOutputs().front().getType(); + for (Value output : getOutputs().drop_front()) + if (output.getType() != outputType) + return emitError("all outputs must have the same type"); + + Block& block = getBody().front(); + if (block.getNumArguments() != 1) + return emitError("body must have exactly one block argument"); + if (block.getArgument(0).getType() != inputType) + return emitError("body block argument type must match input type"); + + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return emitError("body must terminate with spat.yield"); + if (yieldOp.getNumOperands() != 1) + return emitError("body yield must produce exactly one value"); + if (yieldOp.getOperand(0).getType() != outputType) + return emitError("body yield type must match output type"); + + return success(); +} + LogicalResult SpatCompute::verify() { auto& block = getBody().front(); if (block.mightHaveTerminator()) { @@ -365,10 +415,24 @@ LogicalResult SpatChannelSendBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } +LogicalResult SpatChannelSendManyBatchOp::verify() { + if (failed(verifyManyBatchChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch"); +} + LogicalResult SpatChannelReceiveBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } +LogicalResult SpatChannelReceiveManyBatchOp::verify() { + if (failed(verifyManyBatchChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch"); +} + LogicalResult SpatComputeBatch::verify() { int32_t count 
= getLaneCount(); if (count <= 0) @@ -405,18 +469,18 @@ LogicalResult SpatComputeBatch::verify() { return emitError("all outputs must have the same type"); } - if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) { + if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) { auto coreIdsAttr = dyn_cast(coreIdAttr); if (!coreIdsAttr) - return emitError("compute_batch core_id attribute must be a dense i32 array"); + return emitError("compute_batch coreIds attribute must be a dense i32 array"); if (coreIdsAttr.size() != laneCountSz) - return emitError("compute_batch core_id array length must match laneCount"); + return emitError("compute_batch coreIds array length must match laneCount"); if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) - return emitError("compute_batch core_id values must be positive"); + return emitError("compute_batch coreIds values must be positive"); llvm::SmallDenseSet seenCoreIds; for (int32_t coreId : coreIdsAttr.asArrayRef()) if (!seenCoreIds.insert(coreId).second) - return emitError("compute_batch core_id values must be distinct"); + return emitError("compute_batch coreIds values must be distinct"); } Block& block = getBody().front(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 7cb69b0..df0422d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1,5 +1,7 @@ #include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" @@ -35,6 +37,7 @@ #include #include "DCPGraph/DCPAnalysis.hpp" +#include 
"RegularOpCompaction.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -147,7 +150,7 @@ static SmallVector getMaterializedBatchCoreIds(size_t startCpu, size_t } static SmallVector getBatchCoreIds(Operation* op, size_t laneCount) { - if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); if (auto coreIdAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return SmallVector(laneCount, static_cast(coreIdAttr.getInt())); @@ -304,7 +307,7 @@ static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp, SmallVector coreIds = getBatchCoreIds(batch, static_cast(batch.getLaneCount())); if (!coreIds.empty()) - newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); auto* newBlock = rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef {}); @@ -548,141 +551,6 @@ void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) { } } -static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { - IRRewriter rewriter(funcOp.getContext()); - - for (auto compute : funcOp.getOps()) { - Block& block = compute.getBody().front(); - for (auto it = block.begin(); it != block.end();) { - auto receiveOp = dyn_cast(&*it); - if (receiveOp) { - SmallVector run; - Type outputType = receiveOp.getOutput().getType(); - auto runIt = it; - while (runIt != block.end()) { - auto current = dyn_cast(&*runIt); - if (!current || current.getOutput().getType() != outputType) - break; - run.push_back(current); - ++runIt; - } - - if (run.size() > 1) { - struct ReceiveEntry { - spatial::SpatChannelReceiveOp op; - size_t originalIndex = 0; - uint32_t sourceCoreId = 0; - 
uint32_t targetCoreId = 0; - uint64_t channelId = 0; - }; - SmallVector sortedEntries; - sortedEntries.reserve(run.size()); - for (auto [originalIndex, op] : llvm::enumerate(run)) - sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); - llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector outputTypes; - channelIds.reserve(sortedEntries.size()); - sourceCoreIds.reserve(sortedEntries.size()); - targetCoreIds.reserve(sortedEntries.size()); - outputTypes.reserve(sortedEntries.size()); - for (ReceiveEntry& entry : sortedEntries) { - (void) entry; - channelIds.push_back(nextChannelId++); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - outputTypes.push_back(entry.op.getOutput().getType()); - } - - rewriter.setInsertionPoint(run.front()); - auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter, - run.front().getLoc(), - TypeRange(outputTypes), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); - for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) - entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex)); - for (auto op : run) - rewriter.eraseOp(op); - - it = compactReceive->getIterator(); - ++it; - continue; - } - } - - auto sendOp = dyn_cast(&*it); - if (sendOp) { - SmallVector run; - Type inputType = sendOp.getInput().getType(); - auto runIt = it; - while (runIt != block.end()) { - auto current = dyn_cast(&*runIt); - if (!current || current.getInput().getType() != inputType) - break; - run.push_back(current); - ++runIt; - } - 
- if (run.size() > 1) { - struct SendEntry { - spatial::SpatChannelSendOp op; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - uint64_t channelId = 0; - }; - SmallVector sortedEntries; - sortedEntries.reserve(run.size()); - for (auto op : run) - sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); - llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector inputs; - channelIds.reserve(sortedEntries.size()); - sourceCoreIds.reserve(sortedEntries.size()); - targetCoreIds.reserve(sortedEntries.size()); - inputs.reserve(sortedEntries.size()); - for (SendEntry& entry : sortedEntries) { - (void) entry; - channelIds.push_back(nextChannelId++); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - inputs.push_back(entry.op.getInput()); - } - - rewriter.setInsertionPoint(run.front()); - spatial::SpatChannelSendManyOp::create(rewriter, - run.front().getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), - ValueRange(inputs)); - for (auto op : run) - rewriter.eraseOp(op); - - it = runIt; - continue; - } - } - - ++it; - } - } -} - void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { IRRewriter rewriter(funcOp.getContext()); SmallVector computes(funcOp.getOps()); @@ -755,7 +623,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { rebatched.getProperties().setOperandSegmentSizes( {static_cast(weights.size()), static_cast(inputs.size())}); if (haveAllCoreIds) - rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + 
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); SmallVector blockArgTypes; SmallVector blockArgLocs; @@ -1879,6 +1747,9 @@ public: rebatchEquivalentComputes(func, nextChannelId); compactScalarChannelRuns(func, nextChannelId); + compactBatchChannelRuns(func); + compactRegularOpRuns(func); + compactRowWiseWvmmRuns(func); if (!sortTopologically(&func.getBody().front())) { func.emitOpError("failed to topologically order merged Spatial IR"); signalPassFailure(); @@ -2049,7 +1920,7 @@ private: rewriter.getI32IntegerAttr(static_cast(laneCount)), ValueRange(weights), ValueRange(inputs)); - rebatched->setAttr(onnx_mlir::kCoreIdAttrName, + rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount))); SmallVector blockArgTypes; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp new file mode 100644 index 0000000..bd50aa2 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -0,0 +1,577 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include + +#include "RegularOpCompaction.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +enum class RegularStepKind { + Wvmm, + VAddLhs, + VAddRhs, +}; + +struct RegularStep { + RegularStepKind kind; + int32_t weightIndex = 0; + Value invariantOperand; + Type resultType; +}; + +struct RegularChunk { + Operation* startOp = nullptr; + SmallVector ops; + SmallVector steps; + Value input; + Value output; +}; + +static 
bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { + return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand + && lhs.resultType == rhs.resultType; +} +/* Two chunks are equivalent when their input types, output types and step counts match and every pair of corresponding steps is equivalent. The input values themselves are not compared. */ +static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) { + if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType() + || lhs.steps.size() != rhs.steps.size()) { + return false; + } + + return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps), + [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); +} +/* Builds the chunk rooted at startOp: records the Wvmm step, then follows the result through its users for as long as the current value has exactly one use, that user lives in the same block as startOp, and the user casts to the VAdd op type. Every code path visible here returns a chunk; no failure is produced. */ +static FailureOr analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) { + RegularChunk chunk; + chunk.startOp = startOp.getOperation(); + chunk.input = startOp.getInput(); + chunk.output = startOp.getOutput(); + chunk.ops.push_back(startOp.getOperation()); + chunk.steps.push_back( + {RegularStepKind::Wvmm, static_cast(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()}); +/* Walk the single-use chain, noting on which side the chained value enters each VAdd; the other operand is stored as the step's invariant operand. */ + Value currentValue = startOp.getOutput(); + while (currentValue.hasOneUse()) { + Operation* user = *currentValue.getUsers().begin(); + if (user->getBlock() != startOp->getBlock()) + break; + + auto vaddOp = dyn_cast(user); + if (!vaddOp) + break; + + if (vaddOp.getLhs() == currentValue) + chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()}); + else if (vaddOp.getRhs() == currentValue) + chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()}); + else + break; + + chunk.ops.push_back(vaddOp); + chunk.output = vaddOp.getOutput(); + currentValue = vaddOp.getOutput(); + } + + return chunk; +} +/* Fills the body region of mapOp: creates one block whose single argument has the anchor chunk's input type, clones the anchor chunk's ops into it with the chunk input remapped to that argument, and yields the remapped chunk output. */ +static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) { + auto* block = rewriter.createBlock( + &mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()},
{anchorChunk.startOp->getLoc()}); + rewriter.setInsertionPointToEnd(block); +/* Inside the new block the chunk input is represented by the block argument; any other operand of the cloned ops keeps referring to its original value. */ + IRMapping mapping; + mapping.map(anchorChunk.input, block->getArgument(0)); + + for (Operation* op : anchorChunk.ops) { + Operation* cloned = rewriter.clone(*op, mapping); + for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults())) + mapping.map(oldResult, newResult); + } + + spatial::SpatYieldOp::create( + rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)}); +} +/* Replaces a non-empty run of chunks with one map op created just before the first (anchor) chunk: every chunk's input becomes an operand, every chunk's output type becomes a result type, the body is built from the anchor chunk only, each chunk output is replaced by the map result with the same index, and all chunk ops are erased in reverse order. */ +static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef run) { + assert(!run.empty() && "expected a non-empty regular chunk run"); + const RegularChunk& anchorChunk = run.front(); + + SmallVector inputs; + SmallVector outputTypes; + inputs.reserve(run.size()); + outputTypes.reserve(run.size()); + for (const RegularChunk& chunk : run) { + inputs.push_back(chunk.input); + outputTypes.push_back(chunk.output.getType()); + } + + rewriter.setInsertionPoint(anchorChunk.startOp); + auto mapOp = + spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs)); + buildRegularMapBody(mapOp, anchorChunk, rewriter); + + for (auto [index, chunk] : llvm::enumerate(run)) { + Value output = chunk.output; + output.replaceAllUsesWith(mapOp.getResult(index)); + } + + SmallVector opsToErase; + for (const RegularChunk& chunk : run) + llvm::append_range(opsToErase, chunk.ops); + for (Operation* op : llvm::reverse(opsToErase)) + rewriter.eraseOp(op); +} + +} // namespace +/* For every op visited in funcOp (loop variable compute), scans the first block of its body and merges each run of two or more adjacent channel receives with the same output type, and each run of two or more adjacent channel sends with the same input type, into a single receive-many or send-many op. Entries are ordered by (source core id, target core id, channel id) and every merged entry takes a new id from nextChannelId, which is incremented. */ +void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto compute : funcOp.getOps()) { + Block& block = compute.getBody().front(); + for (auto it = block.begin(); it != block.end();) { + auto receiveOp = dyn_cast(&*it); + if (receiveOp) { + SmallVector run; + Type outputType = receiveOp.getOutput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current =
dyn_cast(&*runIt); + if (!current || current.getOutput().getType() != outputType) + break; + run.push_back(current); + ++runIt; + } +/* A single receive is left untouched; only runs of two or more are merged. */ + if (run.size() > 1) { + struct ReceiveEntry { + spatial::SpatChannelReceiveOp op; + size_t originalIndex = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + uint64_t channelId = 0; + }; + SmallVector sortedEntries; + sortedEntries.reserve(run.size()); + for (auto [originalIndex, op] : llvm::enumerate(run)) + sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); +/* The original channel id serves only as a sort key; the ids stored on the merged op are drawn from nextChannelId in sorted order. */ + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector outputTypes; + channelIds.reserve(sortedEntries.size()); + sourceCoreIds.reserve(sortedEntries.size()); + targetCoreIds.reserve(sortedEntries.size()); + outputTypes.reserve(sortedEntries.size()); + for (ReceiveEntry& entry : sortedEntries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + outputTypes.push_back(entry.op.getOutput().getType()); + } +/* The merged op is created at the position of the first receive of the run; result i replaces the output of the i-th entry in sorted order. */ + rewriter.setInsertionPoint(run.front()); + auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter, + run.front().getLoc(), + TypeRange(outputTypes), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); + for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) + entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex)); + for (auto op : run) + rewriter.eraseOp(op); +/* The run's ops are gone, so scanning resumes at the op following the merged receive. */ + it = compactReceive->getIterator(); + ++it; + continue; + } + } +/* Same treatment for a run of sends starting at the current op. */ + auto sendOp = dyn_cast(&*it); + if
(sendOp) { + SmallVector run; + Type inputType = sendOp.getInput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getInput().getType() != inputType) + break; + run.push_back(current); + ++runIt; + } +/* As for receives, only runs of two or more sends are merged. */ + if (run.size() > 1) { + struct SendEntry { + spatial::SpatChannelSendOp op; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + uint64_t channelId = 0; + }; + SmallVector sortedEntries; + sortedEntries.reserve(run.size()); + for (auto op : run) + sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); +/* Same sort key and same renumbering from nextChannelId as on the receive side. */ + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector inputs; + channelIds.reserve(sortedEntries.size()); + sourceCoreIds.reserve(sortedEntries.size()); + targetCoreIds.reserve(sortedEntries.size()); + inputs.reserve(sortedEntries.size()); + for (SendEntry& entry : sortedEntries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + inputs.push_back(entry.op.getInput()); + } +/* No result of the merged send is used here, so the original sends are erased without remapping any uses. */ + rewriter.setInsertionPoint(run.front()); + spatial::SpatChannelSendManyOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + ValueRange(inputs)); + for (auto op : run) + rewriter.eraseOp(op); +/* runIt already points at the first op after the erased run. */ + it = runIt; + continue; + } + } +/* Nothing was merged at this position: move to the next op. */ + ++it; + } + } +} +/* Batch counterpart of compactScalarChannelRuns: for every op visited in funcOp (loop variable batch), merges runs of two or more adjacent batch receives (same output type) or batch sends (same input type) found in the first block of its body. The channel-id and core-id arrays of the run's ops are concatenated in run order; nothing is sorted and no new channel ids are allocated. */ +void compactBatchChannelRuns(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto batch : funcOp.getOps()) { + Block& block = batch.getBody().front(); + for (auto it =
block.begin(); it != block.end();) { + auto receiveOp = dyn_cast(&*it); + if (receiveOp) { + SmallVector run; + Type outputType = receiveOp.getOutput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getOutput().getType() != outputType) + break; + run.push_back(current); + ++runIt; + } +/* Only runs of two or more receives are merged. */ + if (run.size() > 1) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector outputTypes; + outputTypes.reserve(run.size()); + for (auto op : run) { + llvm::append_range(channelIds, op.getChannelIds()); + llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); + llvm::append_range(targetCoreIds, op.getTargetCoreIds()); + outputTypes.push_back(op.getOutput().getType()); + } +/* One merged op is created at the first receive of the run; result i replaces the output of the i-th op of the run. */ + rewriter.setInsertionPoint(run.front()); + auto compactReceive = + spatial::SpatChannelReceiveManyBatchOp::create(rewriter, + run.front().getLoc(), + TypeRange(outputTypes), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); + for (auto [index, op] : llvm::enumerate(run)) + op.getOutput().replaceAllUsesWith(compactReceive.getResult(index)); + for (auto op : run) + rewriter.eraseOp(op); +/* Continue with the op that follows the merged receive. */ + it = compactReceive->getIterator(); + ++it; + continue; + } + } +/* Same treatment for a run of batch sends starting at the current op. */ + auto sendOp = dyn_cast(&*it); + if (sendOp) { + SmallVector run; + Type inputType = sendOp.getInput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getInput().getType() != inputType) + break; + run.push_back(current); + ++runIt; + } +/* Only runs of two or more sends are merged. */ + if (run.size() > 1) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector inputs; + inputs.reserve(run.size()); + for (auto op : run) { + llvm::append_range(channelIds, op.getChannelIds()); + llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); + llvm::append_range(targetCoreIds,
op.getTargetCoreIds()); + inputs.push_back(op.getInput()); + } +/* The merged send is created at the first send of the run and the original sends are erased. */ + rewriter.setInsertionPoint(run.front()); + spatial::SpatChannelSendManyBatchOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + ValueRange(inputs)); + for (auto op : run) + rewriter.eraseOp(op); +/* runIt already points at the first op after the erased run. */ + it = runIt; + continue; + } + } +/* Nothing was merged at this position: move to the next op. */ + ++it; + } + } +} +/* In the first body block of every op visited as compute and as batch, looks for a regular chunk followed directly by further equivalent chunks and replaces each run of two or more with a single map op (see compactRegularChunkRun). NOTE(review): the iterator is advanced by ops.size() of each chunk, which presumes the ops of a chunk are adjacent in the block; analyzeRegularChunk only checks that they share the block. Confirm that adjacency is guaranteed by the callers. */ +void compactRegularOpRuns(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + + auto compactInBlock = [&](Block& block) { + for (auto it = block.begin(); it != block.end();) { + auto startOp = dyn_cast(&*it); + if (!startOp) { + ++it; + continue; + } + + auto anchorChunk = analyzeRegularChunk(startOp); + if (failed(anchorChunk)) { + ++it; + continue; + } +/* Grow the run while the op right after the previous chunk starts a chunk equivalent to the anchor. */ + SmallVector run {*anchorChunk}; + auto runIt = std::next(it, static_cast(anchorChunk->ops.size())); + while (runIt != block.end()) { + auto candidateStart = dyn_cast(&*runIt); + if (!candidateStart) + break; + + auto candidateChunk = analyzeRegularChunk(candidateStart); + if (failed(candidateChunk) || !areEquivalentRegularChunks(*anchorChunk, *candidateChunk)) + break; + + run.push_back(*candidateChunk); + runIt = std::next(runIt, static_cast(candidateChunk->ops.size())); + } +/* A lone chunk is not rewritten; scanning continues from the next op. */ + if (run.size() <= 1) { + ++it; + continue; + } + + compactRegularChunkRun(rewriter, run); + it = runIt; + } + }; + + for (auto compute : funcOp.getOps()) + compactInBlock(compute.getBody().front()); + for (auto batch : funcOp.getOps()) + compactInBlock(batch.getBody().front()); +} +/* For every op visited in funcOp (loop variable compute), rewrites a run of adjacent Wvmm ops that share a weight index, consume consecutive results of the same defining op (extractRowsOp) and whose single-use outputs are consecutive operands of one user (concatOp) into an scf.for loop. Each iteration slices one row from the defining op's input, applies the Wvmm and inserts the result into a packed tensor; the loop result then replaces the run's operands in that user. Positions that fail any check are skipped unchanged. */ +void compactRowWiseWvmmRuns(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto compute : funcOp.getOps()) { + Block& block = compute.getBody().front(); + for (auto it = block.begin(); it != block.end();) { + auto wvmmOp = dyn_cast(&*it); + if (!wvmmOp) { + ++it; + continue; + } + + auto extractRowsOp = wvmmOp.getInput().getDefiningOp(); + auto rowResult =
dyn_cast(wvmmOp.getInput()); + auto outputType = dyn_cast(wvmmOp.getOutput().getType()); + if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType + || !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) { + ++it; + continue; + } +/* Collect adjacent Wvmm ops with the same weight index and the same input and output types whose inputs are results expectedRow, expectedRow + 1, ... of extractRowsOp. */ + SmallVector run; + auto runIt = it; + int64_t expectedRow = static_cast(rowResult.getResultNumber()); + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex() + || current.getInput().getDefiningOp() != extractRowsOp + || current.getInput().getType() != wvmmOp.getInput().getType() + || current.getOutput().getType() != wvmmOp.getOutput().getType()) { + break; + } + + auto currentRow = dyn_cast(current.getInput()); + if (!currentRow || currentRow.getResultNumber() != static_cast(expectedRow)) + break; + + run.push_back(current); + ++expectedRow; + ++runIt; + } +/* A single op is not worth a loop. */ + if (run.size() <= 1) { + ++it; + continue; + } +/* The first output must have exactly one use, and its user must cast to the op type checked below (held in concatOp). */ + if (!run.front().getOutput().hasOneUse()) { + ++it; + continue; + } + auto concatUse = run.front().getOutput().getUses().begin(); + auto concatOp = dyn_cast(concatUse->getOwner()); + if (!concatOp) { + ++it; + continue; + } +/* Every output of the run must be used exactly once, by concatOp, at operand positions concatStartIndex, concatStartIndex + 1, ... in run order. */ + unsigned concatStartIndex = concatUse->getOperandNumber(); + bool validConcatRun = true; + for (auto [index, op] : llvm::enumerate(run)) { + if (!op.getOutput().hasOneUse()) { + validConcatRun = false; + break; + } + OpOperand& use = *op.getOutput().getUses().begin(); + if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) { + validConcatRun = false; + break; + } + } + if (!validConcatRun) { + ++it; + continue; + } +/* Both the row type and the type of the defining op's input must cast successfully and have static shapes. */ + auto inputType = dyn_cast(wvmmOp.getInput().getType()); + auto sourceType = dyn_cast(extractRowsOp.getInput().getType()); + if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) { + ++it; + continue; + } + + int64_t inputCols = inputType.getShape()[1]; + int64_t
outputCols = outputType.getShape()[1]; + if (ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) { + ++it; + continue; + } +/* The packed result has one row per op of the run and outputCols columns. */ + int64_t firstRow = static_cast(rowResult.getResultNumber()); + int64_t runLength = static_cast(run.size()); + auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType()); +/* Emit, before the first op of the run, the loop bounds 0 .. runLength with step 1, an empty packed tensor, and the loop that carries that tensor as its iteration argument. */ + rewriter.setInsertionPoint(run.front()); + auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0); + auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength); + auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1); + auto packedInit = + tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType()); + auto loop = + scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); +/* Loop body: iteration iv reads source row firstRow + iv and writes row iv of the accumulator. The insertion guard restores the rewriter's insertion point when the scope ends. */ + { + OpBuilder::InsertionGuard guard(rewriter); + Block* loopBlock = loop.getBody(); + rewriter.setInsertionPointToStart(loopBlock); + Value iv = loopBlock->getArgument(0); + Value acc = loopBlock->getArgument(1); +/* The add is emitted only when the run does not start at result 0. */ + Value sourceRow = iv; + if (firstRow != 0) { + auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow); + sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue); + } +/* Slice a 1 x inputCols row out of the defining op's input and feed it to a Wvmm with the run's weight index. */ + SmallVector extractOffsets = {sourceRow, rewriter.getIndexAttr(0)}; + SmallVector extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)}; + SmallVector extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto extractedRow = tensor::ExtractSliceOp::create(rewriter, + run.front().getLoc(), + inputType, + extractRowsOp.getInput(), + extractOffsets, + extractSizes, + extractStrides); + auto loopWvmm = spatial::SpatWeightedVMMOp::create( + rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); +/* Insert the 1 x outputCols result at row iv of the accumulator and yield the updated tensor. */ + SmallVector insertOffsets = {iv,
rewriter.getIndexAttr(0)}; + SmallVector insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)}; + SmallVector insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto inserted = tensor::InsertSliceOp::create( + rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides); + scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult()); + } +/* Rebuild the user's operand list: the loop result takes the position of the run's first operand, the run's own operands are dropped, all other operands are kept in order. */ + SmallVector newConcatInputs; + newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1); + for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) { + if (operandIndex == concatStartIndex) + newConcatInputs.push_back(loop.getResult(0)); + if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size()) + newConcatInputs.push_back(operand); + } + rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); }); + for (auto op : run) + rewriter.eraseOp(op); +/* Continue scanning with the op that follows the new loop. */ + it = loop->getIterator(); + ++it; + } + } +} + +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp new file mode 100644 index 0000000..08b7d1e --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" + +#include + +namespace onnx_mlir { +/* Compaction entry points defined in RegularOpCompaction.cpp; each one rewrites the IR inside funcOp in place. Only compactScalarChannelRuns takes nextChannelId, which it advances for every channel id it assigns. */ +void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId); +void compactBatchChannelRuns(mlir::func::FuncOp funcOp); +void compactRegularOpRuns(mlir::func::FuncOp funcOp); +void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp); + +} // namespace onnx_mlir