better spatial IR compaction with better custom syntax, scf.for and
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled

spat.map
This commit is contained in:
NiccoloN
2026-05-06 12:21:58 +02:00
parent 285773fa55
commit b2dc9c38b6
12 changed files with 1442 additions and 274 deletions

View File

@@ -22,6 +22,7 @@
namespace onnx_mlir {
inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id";
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
} // namespace onnx_mlir

View File

@@ -517,8 +517,8 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
}
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName);
assert(coreIdsAttr && "pim.core_batch requires core_id array attribute");
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}

View File

@@ -111,7 +111,7 @@ static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t&
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
@@ -178,6 +178,43 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
rewriter.replaceOp(receiveManyOp, ValueRange(replacements));
}
// Expands a spat.channel_send_many_batch into one pim.send_batch per input.
// Each transferred value owns a laneCount-sized slice of the flattened
// target-core-id metadata, laid out value-major.
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
                                      int32_t laneCount,
                                      IRMapping& mapper,
                                      IRRewriter& rewriter) {
  ArrayRef<int32_t> allTargetIds = sendManyBatchOp.getTargetCoreIds();
  size_t sliceOffset = 0;
  for (Value input : sendManyBatchOp.getInputs()) {
    // Resolve the already-lowered counterpart of this spatial value once.
    Value mappedInput = mapper.lookup(input);
    auto targetIds = allTargetIds.slice(sliceOffset, laneCount);
    pim::PimSendBatchOp::create(rewriter,
                                sendManyBatchOp.getLoc(),
                                mappedInput,
                                getTensorSizeInBytesAttr(rewriter, mappedInput),
                                rewriter.getDenseI32ArrayAttr(targetIds));
    sliceOffset += static_cast<size_t>(laneCount);
  }
}
// Expands a spat.channel_receive_many_batch into one pim.receive_batch per
// output value. Each output consumes a laneCount-sized slice of the flattened
// source-core-id metadata, mirroring the layout used on the send side.
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
int32_t laneCount,
IRMapping& mapper,
IRRewriter& rewriter) {
auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds();
for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
// Metadata layout is value-major: [output0 lane ids..., output1 lane ids...].
size_t metadataOffset = valueIndex * static_cast<size_t>(laneCount);
auto sourceSlice = sourceCoreIds.slice(metadataOffset, laneCount);
auto outputType = cast<ShapedType>(output.getType());
// Fresh destination buffer for the received tensor.
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
receiveManyBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, output),
rewriter.getDenseI32ArrayAttr(sourceSlice))
.getOutput();
// Redirect later uses of the spatial output to the received value.
mapper.map(output, received);
}
}
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
Value input = extractRowsOp.getInput();
RankedTensorType inputType;
@@ -226,6 +263,56 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
rewriter.replaceOp(concatOp, concatenated);
}
// Rewrites spat.wvmm ops that remain inside pim.core / pim.core_batch bodies
// into pim.vmm ops writing to a freshly created output buffer.
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
// NOTE(review): only ops nested in a core body are collected here; any
// spat.wvmm outside a core is left untouched — confirm that is intended.
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
wvmmOps.push_back(wvmmOp);
});
// Collect first, then mutate, so the walk never iterates over erased ops.
for (auto wvmmOp : wvmmOps) {
rewriter.setInsertionPoint(wvmmOp);
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
// Destination-style buffer the pim.vmm writes into.
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
wvmmOp.getOutput().getType(),
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
wvmmOp.getInput(),
outputBuffer);
}
}
// Inlines the single-block body of every spat.map once per input operand,
// replacing each op result with the cloned computation applied to the
// corresponding input.
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first, then mutate, so the walk never iterates over erased ops.
  SmallVector<spatial::SpatMapOp> mapOps;
  funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
  for (auto mapOp : mapOps) {
    Block& body = mapOp.getBody().front();
    auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
    SmallVector<Value> replacements;
    replacements.reserve(mapOp.getInputs().size());
    rewriter.setInsertionPoint(mapOp);
    for (Value input : mapOp.getInputs()) {
      IRMapping mapping;
      // The body's single block argument stands for the current input.
      mapping.map(body.getArgument(0), input);
      for (Operation& op : body.without_terminator()) {
        // Operation::clone(IRMapping&) records result mappings itself, so no
        // explicit result re-mapping is needed after cloning.
        Operation* cloned = rewriter.clone(op, mapping);
        rewriter.setInsertionPointAfter(cloned);
      }
      // If the body merely yields the block argument, this resolves to input.
      replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
    }
    rewriter.replaceOp(mapOp, replacements);
  }
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
@@ -551,6 +638,7 @@ void SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
expandMapOps(funcOp, rewriter);
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
@@ -640,6 +728,32 @@ void SpatialToPimPass::runOnOperation() {
for (auto extractRowsOp : extractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
{
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
}
lowerRemainingSpatialMathOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
@@ -939,7 +1053,7 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
@@ -1000,6 +1114,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
continue;
}
if (auto sendManyBatchOp = dyn_cast<spatial::SpatChannelSendManyBatchOp>(op)) {
lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
@@ -1014,6 +1133,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
continue;
}
if (auto receiveManyBatchOp = dyn_cast<spatial::SpatChannelReceiveManyBatchOp>(op)) {
lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);

View File

@@ -39,7 +39,7 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
}];
}
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> {
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
let summary = "Execute equivalent batched core bodies";
let regions = (region SizedRegion<1>:$body);

View File

@@ -257,8 +257,8 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
auto newOp = PimCoreBatchOp::create(
rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
newOp.getProperties().setOperandSegmentSizes({static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName))
newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr);
if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdsAttrName))
newOp->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin());
for (Block& block : newOp.getBody())

View File

@@ -8,6 +8,7 @@ add_pim_library(SpatialOps
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp

View File

@@ -102,6 +102,23 @@ def SpatConcatOp : SpatOp<"concat", []> {
let hasCustomAssemblyFormat = 1;
}
// Lane-parallel tensor map: applies one single-block region independently to
// each input tensor, yielding one output per input.
def SpatMapOp : SpatOp<"map", [SingleBlock]> {
let summary = "Apply the same lane-local region to many independent tensors";
let arguments = (ins
Variadic<SpatTensor>:$inputs
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
// Single-block body; lowering expands this region once per input operand.
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
@@ -184,6 +201,20 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
let hasCustomAssemblyFormat = 1;
}
// Batched many-value channel send. The metadata arrays are flattened
// value-major: for each input, one laneCount-sized run of ids.
// NOTE(review): metadata layout inferred from the lowering code — confirm
// against the op verifier.
def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> {
let summary = "Send multiple per-lane tensors through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let summary = "Receive a per-lane tensor through logical channels in a batch body";
@@ -201,11 +232,28 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let hasCustomAssemblyFormat = 1;
}
// Batched many-value channel receive; the dual of channel_send_many_batch.
// Produces one output tensor per received value; metadata arrays are
// flattened value-major with laneCount-sized runs.
def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
let summary = "Receive multiple per-lane tensors through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins

View File

@@ -42,6 +42,10 @@ static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter)
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
}
// Accepts either spelling of an optional keyword, trying the preferred one
// first and falling back to the legacy alias.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  if (succeeded(parser.parseOptionalKeyword(preferred)))
    return true;
  return succeeded(parser.parseOptionalKeyword(legacy));
}
template <typename EntryT, typename ParseEntryFn>
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
@@ -75,13 +79,26 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
}
template <typename IntT>
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
if (parser.parseLSquare())
return failure();
if (succeeded(parser.parseOptionalRSquare()))
static ParseResult
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
@@ -118,8 +135,9 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parser.parseOptionalRSquare()))
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
@@ -128,6 +146,14 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm
return success();
}
template <typename IntT>
static ParseResult
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
return parseCompressedIntegerEntries(parser, delimiter, values);
}
template <typename RangeT, typename PrintEntryFn>
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
@@ -146,35 +172,51 @@ static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin
}
template <typename IntT>
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printer << "[";
for (size_t index = 0; index < values.size();) {
if (index != 0)
printer << ", ";
auto findEqualRunEnd = [&](size_t start) {
size_t end = start + 1;
while (end < values.size() && values[end] == values[start])
++end;
return end;
static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
size_t firstRunEnd = findEqualRunEnd(index);
size_t repeatCount = firstRunEnd - index;
Kind kind = Kind::Single;
size_t covered = 1;
size_t repeatCount = 1;
size_t progressionValueCount = 1;
int64_t step = 1;
IntT firstValue {};
IntT lastValue {};
};
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[index];
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[index]);
if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) {
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != repeatCount)
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
@@ -188,27 +230,99 @@ static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> val
}
}
size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount;
if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
printer << values[index] << " to " << lastValue;
if (step != 1)
printer << " by " << step;
if (repeatCount > 1)
printer << " x" << repeatCount;
index = progressionEnd;
continue;
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (repeatCount > 1) {
printer << values[index] << " x" << repeatCount;
index = firstRunEnd;
continue;
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
printer << values[index];
index = firstRunEnd;
return compression;
};
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
printer << "]";
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) {
if (index != 0)
printer << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
printer << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
printer << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
printer << " by " << flat.step;
if (flat.repeatCount > 1)
printer << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
printer << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
printer << flat.firstValue;
index += flat.covered;
break;
}
}
printCloseDelimiter(printer, delimiter);
}
template <typename IntT>
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
}
template <typename IntT>
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
}
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
@@ -267,6 +381,165 @@ static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List
printCloseDelimiter(printer, delimiter);
}
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
// Returns true when `values` consists of a whole number of identical tuples
// of `tupleSize` consecutive values (used to print `(a, b) xN` runs).
// Avoids materializing the range into a temporary vector: ValueRange supports
// take_front/slice views directly.
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
  if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
    return false;
  ValueRange tuple = values.take_front(tupleSize);
  for (size_t index = tupleSize; index < values.size(); index += tupleSize)
    if (!llvm::equal(tuple, values.slice(index, tupleSize)))
      return false;
  return true;
}
// Returns true when `types` consists of a whole number of identical tuples
// of `tupleSize` consecutive types (used to print `(T, U) xN` runs).
// Avoids materializing the range into a temporary vector: TypeRange supports
// take_front/slice views directly.
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
  if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
    return false;
  TypeRange tuple = types.take_front(tupleSize);
  for (size_t index = tupleSize; index < types.size(); index += tupleSize)
    if (!llvm::equal(tuple, types.slice(index, tupleSize)))
      return false;
  return true;
}
// Emits `[(v0, ..., vK-1) xN]`, where N is the number of repeated tuples.
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
  printer << "[";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  bool needsSeparator = false;
  for (Value value : values.take_front(tupleSize)) {
    if (needsSeparator)
      printer << ", ";
    printer.printOperand(value);
    needsSeparator = true;
  }
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " x" << (values.size() / tupleSize) << "]";
}
// Emits `[(T0, ..., TK-1) xN]`, where N is the number of repeated tuples.
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
  printer << "[";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  bool needsSeparator = false;
  for (Type type : types.take_front(tupleSize)) {
    if (needsSeparator)
      printer << ", ";
    printer.printType(type);
    needsSeparator = true;
  }
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " x" << (types.size() / tupleSize) << "]";
}
// Parses `[...]` holding either flat compressed operand entries or repeated
// operand tuples of the form `(a, b) xN`. Expanded operands are appended to
// `operands`. The first-tuple and subsequent-tuple parsing previously
// duplicated the same logic; it is now shared through one local lambda.
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parser.parseLSquare())
    return failure();
  if (succeeded(parser.parseOptionalRSquare()))
    return success();
  // Parses the remainder of one tuple entry; the opening '(' is already
  // consumed when this runs.
  auto parseTupleBody = [&]() -> ParseResult {
    SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
    if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
      return failure();
    int64_t repeatCount = 1;
    if (succeeded(parser.parseOptionalKeyword("x"))) {
      if (parser.parseInteger(repeatCount) || repeatCount <= 0)
        return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    }
    for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
      llvm::append_range(operands, tupleOperands);
    return success();
  };
  if (succeeded(parser.parseOptionalLParen())) {
    if (parseTupleBody())
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parser.parseLParen() || parseTupleBody())
        return failure();
    return parser.parseRSquare();
  }
  // Flat form: comma-separated compressed entries up to the closing ']'.
  while (true) {
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();
    if (succeeded(parser.parseOptionalRSquare()))
      return success();
    if (parser.parseComma())
      return failure();
  }
}
// Parses `[...]` holding either flat `type xN` entries or repeated type
// tuples `(T1, T2) xN`. Expanded types are appended to `types`. The repeat
// count and tuple-entry parsing were each duplicated up to three times;
// both are now shared through local lambdas.
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
  if (parser.parseLSquare())
    return failure();
  if (succeeded(parser.parseOptionalRSquare()))
    return success();
  // Parses an optional `xN` suffix; a missing suffix means a count of one.
  auto parseRepeatCount = [&](int64_t& repeatCount) -> ParseResult {
    repeatCount = 1;
    if (succeeded(parser.parseOptionalKeyword("x"))) {
      if (parser.parseInteger(repeatCount) || repeatCount <= 0)
        return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    }
    return success();
  };
  // Parses the remainder of one tuple entry; the opening '(' is already
  // consumed when this runs.
  auto parseTupleBody = [&]() -> ParseResult {
    SmallVector<Type> tupleTypes;
    if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
      return failure();
    int64_t repeatCount = 1;
    if (parseRepeatCount(repeatCount))
      return failure();
    for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
      llvm::append_range(types, tupleTypes);
    return success();
  };
  if (succeeded(parser.parseOptionalLParen())) {
    if (parseTupleBody())
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parser.parseLParen() || parseTupleBody())
        return failure();
    return parser.parseRSquare();
  }
  // Flat form: comma-separated `type xN` entries up to the closing ']'.
  while (true) {
    Type type;
    if (parser.parseType(type))
      return failure();
    int64_t repeatCount = 1;
    if (parseRepeatCount(repeatCount))
      return failure();
    types.append(repeatCount, type);
    if (succeeded(parser.parseOptionalRSquare()))
      return success();
    if (parser.parseComma())
      return failure();
  }
}
}
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
@@ -440,19 +713,88 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static void buildImplicitRegionArgs(OpAsmParser& parser,
ArrayRef<Type> inputTypes,
SmallVectorImpl<std::string>& generatedNames,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
generatedNames.reserve(inputTypes.size());
arguments.reserve(inputTypes.size());
for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
generatedNames.push_back("arg" + std::to_string(index + 1));
OpAsmParser::Argument arg;
arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
arg.type = inputType;
arguments.push_back(arg);
// Prints the `<args> = <operands>` binding clause for a region-carrying op:
// `() = ()` for no arguments, a bare `%arg = (...)` for one, and a
// compressed parenthesized list on both sides otherwise.
static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
  const unsigned argCount = block.getNumArguments();
  if (argCount == 0) {
    printer << "() = ()";
    return;
  }
  if (argCount == 1)
    printer.printOperand(block.getArgument(0));
  else
    printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
  printer << " = ";
  printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
// Expands one parsed region-argument entry. The form `%x#a to %x#b` (same
// base name) materializes the numbered arguments a..b inclusive; otherwise
// the single argument is appended unchanged.
// NOTE(review): assumes ranged entries share the SSA base name and are
// ordered by the `#number` suffix — confirm against the printer's output.
static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
OpAsmParser::Argument firstArgument,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::Argument lastArgument;
if (parser.parseArgument(lastArgument))
return failure();
// A range is only valid over one name with non-decreasing numbers.
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
}
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
OpAsmParser::Argument argument;
// All synthesized arguments reuse the first entry's source location.
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
arguments.push_back(argument);
}
return success();
}
arguments.push_back(firstArgument);
return success();
}
// Parses the leading argument of one entry, then delegates range expansion
// (`to`-form handling) to parseCompressedArgumentEntryWithFirst.
static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  OpAsmParser::Argument leadingArgument;
  if (failed(parser.parseArgument(leadingArgument)))
    return failure();
  return parseCompressedArgumentEntryWithFirst(parser, leadingArgument, arguments);
}
// Copies each resolved input type onto the matching parsed region argument;
// both ranges must have the same length (zip_equal asserts this).
static void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  for (auto [parsedArgument, resolvedType] : llvm::zip_equal(arguments, inputTypes))
    parsedArgument.type = resolvedType;
}
// Parses the `<args> = <operands>` binding clause used by region-carrying
// ops. Accepted forms:
//   () = (...)             — no block arguments
//   (%a, %b, ...) = (...)  — multiple (possibly ranged) block arguments
//   %a = (...)             — single block argument, no parentheses
// The first entry of the parenthesized form previously duplicated
// parseOneCompressedArgumentEntry inline; it now reuses that helper.
static ParseResult parseArgumentBindings(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::Argument>& arguments,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (succeeded(parser.parseOptionalLParen())) {
    if (succeeded(parser.parseOptionalRParen())) {
      if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
        return failure();
      return success();
    }
    if (parseOneCompressedArgumentEntry(parser, arguments))
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    if (parser.parseRParen() || parser.parseEqual()
        || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
      return failure();
    return success();
  }
  OpAsmParser::Argument argument;
  if (parser.parseArgument(argument) || parser.parseEqual()
      || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
    return failure();
  arguments.push_back(argument);
  return success();
}
} // namespace
@@ -519,8 +861,8 @@ ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result
void SpatConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis();
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
printer << " ";
printCompressedValueSequence(printer, getInputs());
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
@@ -537,11 +879,7 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
if (parseCompressedOperandSequence(parser, inputs)) {
return failure();
}
@@ -563,14 +901,54 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success();
}
// Prints spat.map as: <bindings> <attrs> : inType -> outType <region>.
// Only one input/output type is printed; the parser replicates it across all
// operands/results, so all lanes must share the same types.
// NOTE(review): uses front() on inputs/outputs — assumes at least one of
// each; the parser enforces this, but a programmatically-built empty op
// would crash here.
void SpatMapOp::print(OpAsmPrinter& printer) {
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInputs().front().getType());
printer << " -> ";
printer.printType(getOutputs().front().getType());
printer << " ";
// Entry-block arguments are shown by the binding clause, not the region.
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
// Parses spat.map: `<bindings> <attrs> : inType -> outType <region>`.
// All inputs share `inType` and all results share `outType`; one result is
// created per input operand.
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
Type inputType;
Type outputType;
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
if (inputs.empty())
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
// Replicate the single written types across every operand and result.
SmallVector<Type> inputTypes(inputs.size(), inputType);
SmallVector<Type> outputTypes(inputs.size(), outputType);
// Each region argument must be bound to exactly one input operand.
if (regionArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
// Region arguments take the (uniform) input type before the region parse.
applyArgumentTypes(inputTypes, regionArgs);
Region* body = result.addRegion();
return parser.parseRegion(*body, regionArgs);
}
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " core_id " << coreIdAttr.getInt();
printer << " coreId " << coreIdAttr.getInt();
printer.printOptionalAttrDict((*this)->getAttrs(),
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
@@ -587,7 +965,6 @@ void SpatCompute::print(OpAsmPrinter& printer) {
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<std::string> generatedArgNames;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
@@ -598,15 +975,10 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
@@ -622,9 +994,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"core_id cannot be specified both positionally and in attr-dict");
"coreId cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
@@ -639,26 +1013,33 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
result.addTypes(outputTypes);
Region* body = result.addRegion();
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
applyArgumentTypes(inputTypes, regionArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
printer << " core_ids ";
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
@@ -671,7 +1052,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int32_t laneCount = 0;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<std::string> generatedArgNames;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
@@ -682,24 +1062,18 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
return failure();
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
if (parseCompressedOrTupleOperandList(parser, weights))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedOrTupleTypeList(parser, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
@@ -709,8 +1083,11 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict");
if (regionArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
@@ -718,7 +1095,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
@@ -726,7 +1103,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
result.addTypes(outputTypes);
Region* body = result.addRegion();
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
applyArgumentTypes(inputTypes, regionArgs);
return parser.parseRegion(*body, regionArgs);
}
@@ -867,6 +1244,55 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
return parser.resolveOperand(input, inputType, result.operands);
}
// Prints spat.channel_send_many_batch: the compressed input list, the
// channel metadata (channelIds/sourceCoreIds/targetCoreIds), the remaining
// attributes, then the compressed input types. The three metadata attrs are
// elided from the attr-dict because printChannelMetadata renders them.
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printCompressedValueSequence(printer, getInputs());
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
// Parses spat.channel_send_many_batch. The "channels ... from ... to ..."
// metadata clause is optional; when absent, the metadata may instead appear
// in the attr-dict (specifying it in both places is an error).
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> inputTypes;
  SmallVector<int64_t> channelIds;
  SmallVector<int32_t> sourceCoreIds;
  SmallVector<int32_t> targetCoreIds;
  if (parseCompressedOperandSequence(parser, inputs))
    return failure();
  // Optional positional metadata: channels [...] from [...] to [...]
  bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (hasMetadata) {
    if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
        || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
        || parseCompressedIntegerList(parser, targetCoreIds))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
    return failure();
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  // Reject double specification: positional clause AND attr-dict entries.
  if (hasMetadata
      && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
          || result.attributes.get("targetCoreIds")))
    return parser.emitError(parser.getCurrentLocation(),
        "channel metadata cannot be specified both positionally and in attr-dict");
  if (hasMetadata) {
    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
  }
  return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
@@ -908,5 +1334,47 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
return success();
}
// Prints spat.channel_receive_many_batch: channel metadata, remaining
// attributes, then the compressed result types. The three metadata attrs
// are elided from the attr-dict because printChannelMetadata renders them.
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, getResultTypes());
}
// Parses spat.channel_receive_many_batch. Mirrors the send-many-batch parser
// but produces results instead of resolving operands. The positional
// "channels ... from ... to ..." clause is optional and mutually exclusive
// with the same attributes appearing in the attr-dict.
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<Type> outputTypes;
  SmallVector<int64_t> channelIds;
  SmallVector<int32_t> sourceCoreIds;
  SmallVector<int32_t> targetCoreIds;
  // Optional positional metadata: channels [...] from [...] to [...]
  bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (hasMetadata) {
    if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
        || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
        || parseCompressedIntegerList(parser, targetCoreIds))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();
  // Reject double specification: positional clause AND attr-dict entries.
  if (hasMetadata
      && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
          || result.attributes.get("targetCoreIds")))
    return parser.emitError(parser.getCurrentLocation(),
        "channel metadata cannot be specified both positionally and in attr-dict");
  if (hasMetadata) {
    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
  }
  result.addTypes(outputTypes);
  return success();
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -83,13 +83,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = dyn_cast<SpatComputeBatch>(weightedOp->getParentOp())) {
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
@@ -144,6 +144,23 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
return success();
}
// Verifies the channel metadata of a many-batch channel op: the three arrays
// must have equal length, the op must sit inside a spat.compute_batch, and
// the metadata length must equal valueCount * laneCount of that parent.
static LogicalResult verifyManyBatchChannelSizes(Operation* op,
    ArrayRef<int64_t> channelIds,
    ArrayRef<int32_t> sourceCoreIds,
    ArrayRef<int32_t> targetCoreIds,
    size_t valueCount) {
  const size_t metadataCount = channelIds.size();
  if (metadataCount != sourceCoreIds.size() || metadataCount != targetCoreIds.size())
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
  auto laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside spat.compute_batch");
  const size_t expectedCount = valueCount * static_cast<size_t>(*laneCount);
  if (metadataCount != expectedCount)
    return op->emitError("channel metadata length must match the number of values times parent laneCount");
  return success();
}
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
@@ -306,6 +323,39 @@ LogicalResult SpatConcatOp::verify() {
return success();
}
// Structural verifier for spat.map: at least one input, matching input and
// output counts, one uniform input type and one uniform output type, and a
// single-argument body terminated by a single-operand spat.yield whose types
// line up with the element input/output types.
LogicalResult SpatMapOp::verify() {
  auto mapInputs = getInputs();
  if (mapInputs.empty())
    return emitError("requires at least one input");
  auto mapOutputs = getOutputs();
  if (mapOutputs.size() != mapInputs.size())
    return emitError("number of outputs must match number of inputs");
  Type elementInputType = mapInputs.front().getType();
  for (size_t i = 1, e = mapInputs.size(); i < e; ++i)
    if (mapInputs[i].getType() != elementInputType)
      return emitError("all inputs must have the same type");
  Type elementOutputType = mapOutputs.front().getType();
  for (size_t i = 1, e = mapOutputs.size(); i < e; ++i)
    if (mapOutputs[i].getType() != elementOutputType)
      return emitError("all outputs must have the same type");
  Block& entryBlock = getBody().front();
  if (entryBlock.getNumArguments() != 1)
    return emitError("body must have exactly one block argument");
  if (entryBlock.getArgument(0).getType() != elementInputType)
    return emitError("body block argument type must match input type");
  auto terminator = dyn_cast_or_null<SpatYieldOp>(entryBlock.getTerminator());
  if (!terminator)
    return emitError("body must terminate with spat.yield");
  if (terminator.getNumOperands() != 1)
    return emitError("body yield must produce exactly one value");
  if (terminator.getOperand(0).getType() != elementOutputType)
    return emitError("body yield type must match output type");
  return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
@@ -365,10 +415,24 @@ LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
// Verifier: check channel metadata sizes against the enclosing batch first,
// then the per-channel element types of the inputs.
LogicalResult SpatChannelSendManyBatchOp::verify() {
  LogicalResult sizesResult = verifyManyBatchChannelSizes(
      getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size());
  if (failed(sizesResult))
    return failure();
  return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
}
// Verifier: a single-value batch receive only needs its channel metadata
// arrays checked against the enclosing batch's lane count.
LogicalResult SpatChannelReceiveBatchOp::verify() {
  return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
// Verifier: check channel metadata sizes against the enclosing batch first,
// then the per-channel element types of the results.
LogicalResult SpatChannelReceiveManyBatchOp::verify() {
  LogicalResult sizesResult = verifyManyBatchChannelSizes(
      getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size());
  if (failed(sizesResult))
    return failure();
  return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
if (count <= 0)
@@ -405,18 +469,18 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) {
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
if (!coreIdsAttr)
return emitError("compute_batch core_id attribute must be a dense i32 array");
return emitError("compute_batch coreIds attribute must be a dense i32 array");
if (coreIdsAttr.size() != laneCountSz)
return emitError("compute_batch core_id array length must match laneCount");
return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
return emitError("compute_batch core_id values must be positive");
return emitError("compute_batch coreIds values must be positive");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch core_id values must be distinct");
return emitError("compute_batch coreIds values must be distinct");
}
Block& block = getBody().front();

View File

@@ -1,5 +1,7 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
@@ -35,6 +37,7 @@
#include <vector>
#include "DCPGraph/DCPAnalysis.hpp"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -147,7 +150,7 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
}
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
@@ -304,7 +307,7 @@ static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
if (!coreIds.empty())
newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
auto* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
@@ -548,141 +551,6 @@ void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
}
}
static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
for (auto compute : funcOp.getOps<SpatCompute>()) {
Block& block = compute.getBody().front();
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
struct ReceiveEntry {
spatial::SpatChannelReceiveOp op;
size_t originalIndex = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
uint64_t channelId = 0;
};
SmallVector<ReceiveEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto [originalIndex, op] : llvm::enumerate(run))
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Type> outputTypes;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
outputTypes.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
outputTypes.push_back(entry.op.getOutput().getType());
}
rewriter.setInsertionPoint(run.front());
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
run.front().getLoc(),
TypeRange(outputTypes),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
for (auto op : run)
rewriter.eraseOp(op);
it = compactReceive->getIterator();
++it;
continue;
}
}
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
if (sendOp) {
SmallVector<spatial::SpatChannelSendOp> run;
Type inputType = sendOp.getInput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
if (!current || current.getInput().getType() != inputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
struct SendEntry {
spatial::SpatChannelSendOp op;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
uint64_t channelId = 0;
};
SmallVector<SendEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto op : run)
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Value> inputs;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput());
}
rewriter.setInsertionPoint(run.front());
spatial::SpatChannelSendManyOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
ValueRange(inputs));
for (auto op : run)
rewriter.eraseOp(op);
it = runIt;
continue;
}
}
++it;
}
}
}
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
@@ -755,7 +623,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
rebatched.getProperties().setOperandSegmentSizes(
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (haveAllCoreIds)
rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
@@ -1879,6 +1747,9 @@ public:
rebatchEquivalentComputes(func, nextChannelId);
compactScalarChannelRuns(func, nextChannelId);
compactBatchChannelRuns(func);
compactRegularOpRuns(func);
compactRowWiseWvmmRuns(func);
if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure();
@@ -2049,7 +1920,7 @@ private:
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange(weights),
ValueRange(inputs));
rebatched->setAttr(onnx_mlir::kCoreIdAttrName,
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName,
rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount)));
SmallVector<Type> blockArgTypes;

View File

@@ -0,0 +1,577 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <tuple>
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Kinds of steps in a "regular" chunk: the anchoring weighted vector-matrix
// multiply, or a vector add distinguished by which operand (lhs or rhs) the
// chunk's running value feeds.
enum class RegularStepKind {
  Wvmm,
  VAddLhs,
  VAddRhs,
};
// One step of a regular chunk, reduced to the fields that determine whether
// two steps are interchangeable (see areEquivalentRegularSteps).
struct RegularStep {
  RegularStepKind kind;
  // Weight index of the wvmm step; left at 0 for vadd steps.
  int32_t weightIndex = 0;
  // For vadd steps: the operand that is NOT the chunk's running value.
  // Null for wvmm steps.
  Value invariantOperand;
  Type resultType;
};
// A straight-line chain of ops beginning at a spat.wvmm and extended through
// single-use spat.vadd users in the same block (built by analyzeRegularChunk).
struct RegularChunk {
  Operation* startOp = nullptr;   // the anchoring wvmm operation
  SmallVector<Operation*> ops;    // all chunk ops, in program order
  SmallVector<RegularStep> steps; // abstracted view used for equivalence checks
  Value input;                    // the wvmm input feeding the chunk
  Value output;                   // result of the chunk's last op
};
// Two steps are interchangeable exactly when all four distinguishing fields
// (kind, weight index, invariant operand, result type) coincide.
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
  return std::tie(lhs.kind, lhs.weightIndex, lhs.invariantOperand, lhs.resultType)
      == std::tie(rhs.kind, rhs.weightIndex, rhs.invariantOperand, rhs.resultType);
}
static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) {
if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType()
|| lhs.steps.size() != rhs.steps.size()) {
return false;
}
return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps),
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
}
// Builds a RegularChunk anchored at `startOp` by walking the single-use chain
// of results: starting from the wvmm output, absorb each spat.vadd user as
// long as the running value has exactly one use, the user sits in the same
// block as the anchor, and the running value is one of the vadd's operands.
// NOTE(review): this function currently always returns success; the FailureOr
// return type leaves room for future failure cases.
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
  RegularChunk chunk;
  chunk.startOp = startOp.getOperation();
  chunk.input = startOp.getInput();
  chunk.output = startOp.getOutput();
  chunk.ops.push_back(startOp.getOperation());
  // The anchor step records the weight index; vadd steps below leave it 0.
  chunk.steps.push_back(
      {RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
  Value currentValue = startOp.getOutput();
  while (currentValue.hasOneUse()) {
    // Safe: hasOneUse() guarantees exactly one user.
    Operation* user = *currentValue.getUsers().begin();
    if (user->getBlock() != startOp->getBlock())
      break;
    auto vaddOp = dyn_cast<spatial::SpatVAddOp>(user);
    if (!vaddOp)
      break;
    // Record which side the running value feeds and keep the other operand
    // as the step's invariant operand.
    if (vaddOp.getLhs() == currentValue)
      chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
    else if (vaddOp.getRhs() == currentValue)
      chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
    else
      break;
    chunk.ops.push_back(vaddOp);
    chunk.output = vaddOp.getOutput();
    currentValue = vaddOp.getOutput();
  }
  return chunk;
}
// Populates the body of `mapOp` by cloning the anchor chunk's ops into a new
// single-argument entry block. The block argument stands in for the chunk
// input, and the cloned chunk output is yielded, matching spat.map's contract
// of one block argument and one yielded value.
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
  auto* block = rewriter.createBlock(
      &mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
  rewriter.setInsertionPointToEnd(block);
  IRMapping mapping;
  mapping.map(anchorChunk.input, block->getArgument(0));
  // OpBuilder::clone(op, mapping) already records each original result ->
  // cloned result in `mapping`, so later clones pick up earlier results
  // automatically; no explicit result remapping is needed.
  for (Operation* op : anchorChunk.ops)
    rewriter.clone(*op, mapping);
  spatial::SpatYieldOp::create(
      rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
}
// Replaces a run of equivalent regular chunks with one spat.map whose body is
// cloned from the first ("anchor") chunk: each chunk's input becomes a map
// operand, each chunk's output is replaced by the corresponding map result,
// and the original ops are erased in reverse order (uses before defs).
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
  assert(!run.empty() && "expected a non-empty regular chunk run");
  const RegularChunk& anchorChunk = run.front();
  SmallVector<Value> inputs;
  SmallVector<Type> outputTypes;
  inputs.reserve(run.size());
  outputTypes.reserve(run.size());
  for (const RegularChunk& chunk : run) {
    inputs.push_back(chunk.input);
    outputTypes.push_back(chunk.output.getType());
  }
  // Insert the map where the anchor chunk started so dominance is preserved.
  rewriter.setInsertionPoint(anchorChunk.startOp);
  auto mapOp =
      spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
  buildRegularMapBody(mapOp, anchorChunk, rewriter);
  // Reroute all uses of each chunk's final value to the matching map result.
  for (auto [index, chunk] : llvm::enumerate(run)) {
    Value output = chunk.output;
    output.replaceAllUsesWith(mapOp.getResult(index));
  }
  SmallVector<Operation*> opsToErase;
  for (const RegularChunk& chunk : run)
    llvm::append_range(opsToErase, chunk.ops);
  // Erase last-to-first so no op is deleted while a later chunk op still
  // uses its result.
  for (Operation* op : llvm::reverse(opsToErase))
    rewriter.eraseOp(op);
}
} // namespace
/// Compacts maximal runs of adjacent scalar channel ops inside each
/// spat.compute body into single "many" variants:
/// SpatChannelReceiveOp runs -> SpatChannelReceiveManyOp and
/// SpatChannelSendOp runs -> SpatChannelSendManyOp.
///
/// A run only extends over directly adjacent ops whose element type matches
/// the first op of the run. Entries are stable-sorted by
/// (sourceCoreId, targetCoreId, channelId) before fresh channel ids are drawn
/// from nextChannelId, so that the send side and the receive side of the same
/// logical channels are renumbered in the same deterministic order.
///
/// \param funcOp        Function whose spat.compute bodies are rewritten.
/// \param nextChannelId In/out counter used to allocate fresh channel ids.
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
  IRRewriter rewriter(funcOp.getContext());
  for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
    Block& block = compute.getBody().front();
    for (auto it = block.begin(); it != block.end();) {
      auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
      if (receiveOp) {
        // Collect the maximal run of adjacent receives with the same type.
        SmallVector<spatial::SpatChannelReceiveOp> run;
        Type outputType = receiveOp.getOutput().getType();
        auto runIt = it;
        while (runIt != block.end()) {
          auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
          if (!current || current.getOutput().getType() != outputType)
            break;
          run.push_back(current);
          ++runIt;
        }
        if (run.size() > 1) {
          struct ReceiveEntry {
            spatial::SpatChannelReceiveOp op;
            uint32_t sourceCoreId = 0;
            uint32_t targetCoreId = 0;
            uint64_t channelId = 0;
          };
          SmallVector<ReceiveEntry> sortedEntries;
          sortedEntries.reserve(run.size());
          for (auto op : run)
            sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
          // Deterministic order: the matching send run sorts the same way, so
          // both sides see the fresh channel ids in identical order.
          llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
            return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
                   < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
          });
          SmallVector<int64_t> channelIds;
          SmallVector<int32_t> sourceCoreIds;
          SmallVector<int32_t> targetCoreIds;
          SmallVector<Type> outputTypes;
          channelIds.reserve(sortedEntries.size());
          sourceCoreIds.reserve(sortedEntries.size());
          targetCoreIds.reserve(sortedEntries.size());
          outputTypes.reserve(sortedEntries.size());
          for (const ReceiveEntry& entry : sortedEntries) {
            channelIds.push_back(nextChannelId++);
            sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
            targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
            outputTypes.push_back(entry.op.getOutput().getType());
          }
          rewriter.setInsertionPoint(run.front());
          auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
              run.front().getLoc(),
              TypeRange(outputTypes),
              rewriter.getDenseI64ArrayAttr(channelIds),
              rewriter.getDenseI32ArrayAttr(sourceCoreIds),
              rewriter.getDenseI32ArrayAttr(targetCoreIds));
          // Results are in sorted order; rewire each original receive to its
          // corresponding result before erasing the run.
          for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
            entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
          for (auto op : run)
            rewriter.eraseOp(op);
          // Resume scanning right after the newly created op.
          it = compactReceive->getIterator();
          ++it;
          continue;
        }
      }
      auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
      if (sendOp) {
        // Collect the maximal run of adjacent sends with the same input type.
        SmallVector<spatial::SpatChannelSendOp> run;
        Type inputType = sendOp.getInput().getType();
        auto runIt = it;
        while (runIt != block.end()) {
          auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
          if (!current || current.getInput().getType() != inputType)
            break;
          run.push_back(current);
          ++runIt;
        }
        if (run.size() > 1) {
          struct SendEntry {
            spatial::SpatChannelSendOp op;
            uint32_t sourceCoreId = 0;
            uint32_t targetCoreId = 0;
            uint64_t channelId = 0;
          };
          SmallVector<SendEntry> sortedEntries;
          sortedEntries.reserve(run.size());
          for (auto op : run)
            sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
          // Same ordering criterion as the receive side (see above).
          llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
            return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
                   < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
          });
          SmallVector<int64_t> channelIds;
          SmallVector<int32_t> sourceCoreIds;
          SmallVector<int32_t> targetCoreIds;
          SmallVector<Value> inputs;
          channelIds.reserve(sortedEntries.size());
          sourceCoreIds.reserve(sortedEntries.size());
          targetCoreIds.reserve(sortedEntries.size());
          inputs.reserve(sortedEntries.size());
          for (const SendEntry& entry : sortedEntries) {
            channelIds.push_back(nextChannelId++);
            sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
            targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
            inputs.push_back(entry.op.getInput());
          }
          rewriter.setInsertionPoint(run.front());
          spatial::SpatChannelSendManyOp::create(rewriter,
              run.front().getLoc(),
              rewriter.getDenseI64ArrayAttr(channelIds),
              rewriter.getDenseI32ArrayAttr(sourceCoreIds),
              rewriter.getDenseI32ArrayAttr(targetCoreIds),
              ValueRange(inputs));
          for (auto op : run)
            rewriter.eraseOp(op);
          // runIt already points at the first op past the (now erased) run.
          it = runIt;
          continue;
        }
      }
      ++it;
    }
  }
}
/// Compacts maximal runs of adjacent batch channel ops inside each
/// spat.compute_batch body into single "many batch" variants:
/// SpatChannelReceiveBatchOp runs -> SpatChannelReceiveManyBatchOp and
/// SpatChannelSendBatchOp runs -> SpatChannelSendManyBatchOp.
/// Unlike the scalar variant, the existing channel/core id arrays of each op
/// are concatenated in program order — no re-sorting and no fresh id
/// allocation happens here.
void compactBatchChannelRuns(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>()) {
    Block& block = batch.getBody().front();
    for (auto it = block.begin(); it != block.end();) {
      auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
      if (receiveOp) {
        // Maximal run of directly adjacent batch receives with the same
        // result type as the first one.
        SmallVector<spatial::SpatChannelReceiveBatchOp> run;
        Type outputType = receiveOp.getOutput().getType();
        auto runIt = it;
        while (runIt != block.end()) {
          auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
          if (!current || current.getOutput().getType() != outputType)
            break;
          run.push_back(current);
          ++runIt;
        }
        if (run.size() > 1) {
          // Concatenate all per-op id arrays in program order; each original
          // op contributes one result slot of the compacted op.
          SmallVector<int64_t> channelIds;
          SmallVector<int32_t> sourceCoreIds;
          SmallVector<int32_t> targetCoreIds;
          SmallVector<Type> outputTypes;
          outputTypes.reserve(run.size());
          for (auto op : run) {
            llvm::append_range(channelIds, op.getChannelIds());
            llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
            llvm::append_range(targetCoreIds, op.getTargetCoreIds());
            outputTypes.push_back(op.getOutput().getType());
          }
          rewriter.setInsertionPoint(run.front());
          auto compactReceive =
              spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
                  run.front().getLoc(),
                  TypeRange(outputTypes),
                  rewriter.getDenseI64ArrayAttr(channelIds),
                  rewriter.getDenseI32ArrayAttr(sourceCoreIds),
                  rewriter.getDenseI32ArrayAttr(targetCoreIds));
          // Rewire each original receive to its positional result, then erase.
          for (auto [index, op] : llvm::enumerate(run))
            op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
          for (auto op : run)
            rewriter.eraseOp(op);
          // Resume scanning right after the newly created op.
          it = compactReceive->getIterator();
          ++it;
          continue;
        }
      }
      auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
      if (sendOp) {
        // Maximal run of directly adjacent batch sends with the same input
        // type as the first one.
        SmallVector<spatial::SpatChannelSendBatchOp> run;
        Type inputType = sendOp.getInput().getType();
        auto runIt = it;
        while (runIt != block.end()) {
          auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
          if (!current || current.getInput().getType() != inputType)
            break;
          run.push_back(current);
          ++runIt;
        }
        if (run.size() > 1) {
          // Concatenate id arrays and gather the per-op inputs in order.
          SmallVector<int64_t> channelIds;
          SmallVector<int32_t> sourceCoreIds;
          SmallVector<int32_t> targetCoreIds;
          SmallVector<Value> inputs;
          inputs.reserve(run.size());
          for (auto op : run) {
            llvm::append_range(channelIds, op.getChannelIds());
            llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
            llvm::append_range(targetCoreIds, op.getTargetCoreIds());
            inputs.push_back(op.getInput());
          }
          rewriter.setInsertionPoint(run.front());
          spatial::SpatChannelSendManyBatchOp::create(rewriter,
              run.front().getLoc(),
              rewriter.getDenseI64ArrayAttr(channelIds),
              rewriter.getDenseI32ArrayAttr(sourceCoreIds),
              rewriter.getDenseI32ArrayAttr(targetCoreIds),
              ValueRange(inputs));
          for (auto op : run)
            rewriter.eraseOp(op);
          // runIt already points at the first op past the (now erased) run.
          it = runIt;
          continue;
        }
      }
      ++it;
    }
  }
}
/// Fuses runs of adjacent, structurally equivalent "regular" op chunks —
/// each anchored at a spat.weighted_vmm — inside every spat.compute and
/// spat.compute_batch body. Chunk recognition and equivalence are delegated
/// to analyzeRegularChunk / areEquivalentRegularChunks, and the actual
/// rewrite to compactRegularChunkRun.
void compactRegularOpRuns(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  auto processBlock = [&rewriter](Block& block) {
    auto cursor = block.begin();
    while (cursor != block.end()) {
      auto anchorOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*cursor);
      if (!anchorOp) {
        ++cursor;
        continue;
      }
      auto firstChunk = analyzeRegularChunk(anchorOp);
      if (failed(firstChunk)) {
        ++cursor;
        continue;
      }
      // Greedily extend the run with back-to-back equivalent chunks.
      SmallVector<RegularChunk> chunkRun;
      chunkRun.push_back(*firstChunk);
      auto scanIt = std::next(cursor, static_cast<std::ptrdiff_t>(firstChunk->ops.size()));
      while (scanIt != block.end()) {
        auto nextAnchor = dyn_cast<spatial::SpatWeightedVMMOp>(&*scanIt);
        if (!nextAnchor)
          break;
        auto nextChunk = analyzeRegularChunk(nextAnchor);
        if (failed(nextChunk) || !areEquivalentRegularChunks(*firstChunk, *nextChunk))
          break;
        chunkRun.push_back(*nextChunk);
        scanIt = std::next(scanIt, static_cast<std::ptrdiff_t>(nextChunk->ops.size()));
      }
      if (chunkRun.size() > 1) {
        compactRegularChunkRun(rewriter, chunkRun);
        cursor = scanIt;
      } else {
        // Single chunk: step one op forward, as an inner op of this chunk may
        // still anchor a different run.
        ++cursor;
      }
    }
  };
  for (auto compute : funcOp.getOps<spatial::SpatCompute>())
    processBlock(compute.getBody().front());
  for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>())
    processBlock(batch.getBody().front());
}
/// Rewrites runs of row-wise spat.weighted_vmm ops into a single scf.for
/// loop. A run is a sequence of adjacent weighted_vmm ops that (a) share the
/// same weight index, (b) each consume consecutive results of one
/// spat.extract_rows op, (c) each produce a static 1xC tensor, and (d) feed
/// consecutive operands of a single spat.concat op. The run is replaced by an
/// scf.for that slices one source row per iteration, applies the weighted
/// VMM, and inserts the 1xC result into a packed (runLength x C) tensor,
/// which then replaces the run's operands in the concat.
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
    Block& block = compute.getBody().front();
    for (auto it = block.begin(); it != block.end();) {
      auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
      if (!wvmmOp) {
        ++it;
        continue;
      }
      // Anchor requirements: input comes straight from a spat.extract_rows
      // result, and the output is a static rank-2 tensor with exactly 1 row.
      auto extractRowsOp = wvmmOp.getInput().getDefiningOp<spatial::SpatExtractRowsOp>();
      auto rowResult = dyn_cast<OpResult>(wvmmOp.getInput());
      auto outputType = dyn_cast<RankedTensorType>(wvmmOp.getOutput().getType());
      if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType
          || !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) {
        ++it;
        continue;
      }
      // Extend the run while adjacent ops match the anchor (same weight, same
      // extract_rows source, same input/output types) and consume strictly
      // consecutive row results.
      SmallVector<spatial::SpatWeightedVMMOp> run;
      auto runIt = it;
      int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
      while (runIt != block.end()) {
        auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
        if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
            || current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
            || current.getInput().getType() != wvmmOp.getInput().getType()
            || current.getOutput().getType() != wvmmOp.getOutput().getType()) {
          break;
        }
        auto currentRow = dyn_cast<OpResult>(current.getInput());
        if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
          break;
        run.push_back(current);
        ++expectedRow;
        ++runIt;
      }
      if (run.size() <= 1) {
        ++it;
        continue;
      }
      // The whole run must feed consecutive operands of one concat, each
      // output having exactly that single use.
      if (!run.front().getOutput().hasOneUse()) {
        ++it;
        continue;
      }
      auto concatUse = run.front().getOutput().getUses().begin();
      auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
      if (!concatOp) {
        ++it;
        continue;
      }
      unsigned concatStartIndex = concatUse->getOperandNumber();
      bool validConcatRun = true;
      for (auto [index, op] : llvm::enumerate(run)) {
        if (!op.getOutput().hasOneUse()) {
          validConcatRun = false;
          break;
        }
        OpOperand& use = *op.getOutput().getUses().begin();
        if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) {
          validConcatRun = false;
          break;
        }
      }
      if (!validConcatRun) {
        ++it;
        continue;
      }
      // Static shapes are required to build the slice/insert offsets below.
      auto inputType = dyn_cast<RankedTensorType>(wvmmOp.getInput().getType());
      auto sourceType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
      if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) {
        ++it;
        continue;
      }
      int64_t inputCols = inputType.getShape()[1];
      int64_t outputCols = outputType.getShape()[1];
      if (ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) {
        ++it;
        continue;
      }
      int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
      int64_t runLength = static_cast<int64_t>(run.size());
      // Packed accumulator: one row per op in the run.
      auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
      rewriter.setInsertionPoint(run.front());
      auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
      auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
      auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
      auto packedInit =
          tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
      auto loop =
          scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
      {
        OpBuilder::InsertionGuard guard(rewriter);
        Block* loopBlock = loop.getBody();
        rewriter.setInsertionPointToStart(loopBlock);
        Value iv = loopBlock->getArgument(0);
        Value acc = loopBlock->getArgument(1);
        // Source row = firstRow + iv (iv alone when the run starts at row 0).
        Value sourceRow = iv;
        if (firstRow != 0) {
          auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
          sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
        }
        // Slice one 1xinputCols row from the extract_rows source tensor.
        SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
        SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
            run.front().getLoc(),
            inputType,
            extractRowsOp.getInput(),
            extractOffsets,
            extractSizes,
            extractStrides);
        auto loopWvmm = spatial::SpatWeightedVMMOp::create(
            rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
        // Insert the 1xoutputCols result into row iv of the accumulator.
        SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
        SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        auto inserted = tensor::InsertSliceOp::create(
            rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
        scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
      }
      // Rebuild the concat operand list: the run's operands collapse into the
      // single packed loop result at concatStartIndex.
      SmallVector<Value> newConcatInputs;
      newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
      for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
        if (operandIndex == concatStartIndex)
          newConcatInputs.push_back(loop.getResult(0));
        if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
          newConcatInputs.push_back(operand);
      }
      rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
      for (auto op : run)
        rewriter.eraseOp(op);
      // Resume scanning right after the newly created loop.
      it = loop->getIterator();
      ++it;
    }
  }
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,14 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include <cstdint>
namespace onnx_mlir {
/// Fuses adjacent scalar channel send/receive runs inside spat.compute bodies
/// into "many" variants, allocating fresh channel ids from nextChannelId.
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
/// Fuses adjacent batch channel send/receive runs inside spat.compute_batch
/// bodies, concatenating existing channel/core id arrays in program order.
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
/// Fuses runs of structurally equivalent weighted-VMM-anchored chunks in
/// spat.compute and spat.compute_batch bodies.
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
/// Rewrites runs of row-wise weighted-VMM ops over consecutive extracted rows
/// into a single scf.for loop feeding the consuming concat.
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir