compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
Better IR compaction after dcp merge; remove pim.mvm op; better memory report.
This commit is contained in:
@@ -394,30 +394,6 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Matrix-vector multiplication op with mnemonic "mvm". The op is
// destination-style: `$outputBuffer` is reported as its DPS init operand.
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
  let summary = "Matrix-vector multiplication: c = a * b";

  // `$weightIndex` is a 32-bit integer attribute (presumably selecting the
  // weight matrix -- confirm against the lowering); it is not spelled out in
  // the assembly format below, so it is carried by `attr-dict`.
  let arguments = (ins I32Attr:$weightIndex, PimTensor:$input, PimTensor:$outputBuffer);

  let results = (outs PimTensor:$output);

  // DestinationStyleOpInterface hook: the output buffer is the only init.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() { return getOutputBufferMutable(); }
  }];

  let assemblyFormat = [{
    `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
  }];
}
|
||||
|
||||
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise addition: c = a + b";
|
||||
|
||||
|
||||
@@ -538,33 +538,6 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization external model for PimMVMOp.
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
  /// An operand counts as a memory read exactly when it is not a DPS init.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInit = dpsOp.isDpsInit(&opOperand);
    return !isInit;
  }

  /// Rebuilds the op on buffers: looks up the buffer (or value) for the input
  /// and for the output buffer, passes the input through
  /// materializeContiguousMemRef, and replaces the op with a new PimMVMOp
  /// whose result type is the output buffer's type. Fails if either lookup
  /// fails.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    PimMVMOp matVecOp = cast<PimMVMOp>(op);

    // The two lookups stay in this order (input first, then init) and each is
    // checked before the next one runs.
    auto inputBuffer = getBufferOrValue(rewriter, matVecOp.getInput(), options, state);
    if (failed(inputBuffer))
      return failure();

    auto initBuffer = getBufferOrValue(rewriter, matVecOp.getOutputBuffer(), options, state);
    if (failed(initBuffer))
      return failure();

    // NOTE(review): presumably yields a contiguous view/copy of the input;
    // the helper is defined outside this block.
    Value denseInput = materializeContiguousMemRef(*inputBuffer, op->getLoc(), rewriter);

    replaceOpWithNewBufferizedOp<PimMVMOp>(
        rewriter, op, initBuffer->getType(), matVecOp.getWeightIndexAttr(), denseInput, *initBuffer);
    return success();
  }
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
@@ -655,7 +628,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
||||
|
||||
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||
|
||||
@@ -150,10 +150,7 @@ def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
@@ -170,10 +167,7 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
@@ -201,10 +195,7 @@ def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
@@ -238,10 +229,7 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
template <typename TensorSendOpTy>
|
||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
}
|
||||
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
/// Shared parser for tensor-send ops.
///
/// Accepted form:
///   <input> [`channels` <list> `from` <list> `to` <list>] [attr-dict] `:` <type>
///
/// When the positional `channels ... from ... to ...` clause is present, its
/// three lists are stored as the `channelIds`, `sourceCoreIds` and
/// `targetCoreIds` attributes; supplying any of those attributes in the
/// attr-dict as well is rejected. Returns failure on any parse error,
/// otherwise the result of resolving the input operand against the parsed
/// type.
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand input;
  Type inputType;
  SmallVector<int64_t> channelIds;
  SmallVector<int32_t> sourceCoreIds;
  SmallVector<int32_t> targetCoreIds;

  if (parser.parseOperand(input))
    return failure();

  const bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (hasMetadata
      && (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
          || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
          || parseCompressedIntegerList(parser, targetCoreIds)))
    return failure();

  // Remember where the attr-dict starts. Querying the location only after the
  // trailing type has been consumed would make the conflict diagnostic below
  // point at whatever token follows this op instead of at the attr-dict.
  const auto attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
    return failure();

  if (hasMetadata) {
    if (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
        || result.attributes.get("targetCoreIds"))
      return parser.emitError(attrDictLoc,
                              "channel metadata cannot be specified both positionally and in attr-dict");

    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
  }

  return parser.resolveOperand(input, inputType, result.operands);
}
|
||||
|
||||
/// Shared parser for tensor-receive ops.
///
/// Accepted form:
///   [`channels` <list> `from` <list> `to` <list>] [attr-dict] `:` <type>
///
/// When the positional `channels ... from ... to ...` clause is present, its
/// three lists are stored as the `channelIds`, `sourceCoreIds` and
/// `targetCoreIds` attributes; supplying any of those attributes in the
/// attr-dict as well is rejected. On success the parsed type becomes the
/// op's single result type.
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
  Type outputType;
  SmallVector<int64_t> channelIds;
  SmallVector<int32_t> sourceCoreIds;
  SmallVector<int32_t> targetCoreIds;

  const bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (hasMetadata
      && (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
          || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
          || parseCompressedIntegerList(parser, targetCoreIds)))
    return failure();

  // Remember where the attr-dict starts. Querying the location only after the
  // trailing type has been consumed would make the conflict diagnostic below
  // point at whatever token follows this op instead of at the attr-dict.
  const auto attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
    return failure();

  if (hasMetadata) {
    if (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
        || result.attributes.get("targetCoreIds"))
      return parser.emitError(attrDictLoc,
                              "channel metadata cannot be specified both positionally and in attr-dict");

    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
  }

  result.addTypes(outputType);
  return success();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1436,6 +1436,21 @@ public:
|
||||
compactBatchChannelRuns(func);
|
||||
compactRegularOpRuns(func);
|
||||
compactRowWiseWvmmRuns(func);
|
||||
compactScalarChannelRuns(func, nextChannelId);
|
||||
compactBatchChannelRuns(func);
|
||||
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
func.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
op.erase();
|
||||
};
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||
signalPassFailure();
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -41,133 +42,75 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
/// Returns a ranked tensor type with the same element type and trailing
/// dimensions as `elementType`, whose leading dimension is `count` times the
/// original one.
///
/// `elementType` must have rank >= 1, because dimension 0 is accessed
/// unconditionally. A dynamic leading dimension stays dynamic: the dynamic
/// extent is a sentinel value, so multiplying it would overflow or produce a
/// meaningless static size.
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
  if (!ShapedType::isDynamic(packedShape[0]))
    packedShape[0] *= count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}
|
||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||
if (values.empty() || !values.front().hasOneUse())
|
||||
return {};
|
||||
|
||||
static Value
|
||||
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(chunkType.getRank());
|
||||
sizes.reserve(chunkType.getRank());
|
||||
strides.reserve(chunkType.getRank());
|
||||
OpOperand& firstUse = *values.front().getUses().begin();
|
||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(firstUse.getOwner());
|
||||
if (!concatOp)
|
||||
return {};
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
startOperandIndex = firstUse.getOperandNumber();
|
||||
for (auto [index, value] : llvm::enumerate(values)) {
|
||||
if (!value.hasOneUse())
|
||||
return {};
|
||||
OpOperand& use = *value.getUses().begin();
|
||||
if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index)
|
||||
return {};
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
||||
return concatOp;
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp,
|
||||
unsigned startOperandIndex,
|
||||
unsigned operandCount,
|
||||
Value packedValue,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<Value> newInputs;
|
||||
newInputs.reserve(concatOp.getInputs().size() - operandCount + 1);
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||
if (operandIndex == startOperandIndex)
|
||||
newInputs.push_back(packedValue);
|
||||
if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount)
|
||||
newInputs.push_back(operand);
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) {
|
||||
rewriter.replaceOp(concatOp, newInputs.front());
|
||||
return;
|
||||
}
|
||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); });
|
||||
}
|
||||
|
||||
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
|
||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!firstSliceOp)
|
||||
static RankedTensorType
|
||||
getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) {
|
||||
auto firstType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex].getType());
|
||||
if (!firstType || !firstType.hasStaticShape())
|
||||
return {};
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
||||
|| firstType.getRank() == 0)
|
||||
int64_t axis = concatOp.getAxis();
|
||||
if (axis < 0 || axis >= firstType.getRank())
|
||||
return {};
|
||||
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
||||
int64_t rowsPerValue = firstSizes[0];
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
for (size_t index = 1; index < values.size(); ++index) {
|
||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
||||
SmallVector<int64_t> shape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
shape[axis] = 0;
|
||||
for (unsigned index = 0; index < operandCount; ++index) {
|
||||
auto operandType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex + index].getType());
|
||||
if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank())
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
||||
return {};
|
||||
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
||||
for (int64_t dim = 0; dim < firstType.getRank(); ++dim) {
|
||||
if (dim == axis)
|
||||
continue;
|
||||
if (operandType.getShape()[dim] != shape[dim])
|
||||
return {};
|
||||
}
|
||||
|
||||
shape[axis] += operandType.getShape()[axis];
|
||||
}
|
||||
|
||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(firstType.getRank());
|
||||
sizes.reserve(firstType.getRank());
|
||||
strides.reserve(firstType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
|
||||
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
|
||||
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
return RankedTensorType::get(shape, firstType.getElementType());
|
||||
}
|
||||
|
||||
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
||||
@@ -207,8 +150,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
||||
return {};
|
||||
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
|
||||
return {};
|
||||
|
||||
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
|
||||
return {};
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
@@ -346,11 +288,28 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value replacement = extractPackedChunk(
|
||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(replacement);
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (const RegularChunk& chunk : run)
|
||||
outputs.push_back(chunk.output);
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
|
||||
auto concatPackedType = concatOp
|
||||
? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
|
||||
if (concatOp && concatPackedType == packedOutputType) {
|
||||
replaceConcatRunWithPackedValue(
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), loop.getResult(0), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value replacement = extractPackedChunk(
|
||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(replacement);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
@@ -412,7 +371,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
SmallVector<Value> sortedOutputs;
|
||||
sortedOutputs.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries)
|
||||
sortedOutputs.push_back(entry.op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex);
|
||||
auto concatPackedType =
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
@@ -421,9 +391,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(concatOp,
|
||||
concatStartIndex,
|
||||
static_cast<unsigned>(sortedOutputs.size()),
|
||||
compactReceive.getOutput(),
|
||||
rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -531,7 +510,18 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (auto op : run)
|
||||
outputs.push_back(op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
|
||||
auto concatPackedType =
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
@@ -540,9 +530,15 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user