fast pim bufferization using tensors
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m29s
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m29s
This commit is contained in:
@@ -159,9 +159,7 @@ static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRe
|
||||
rewriter.eraseOp(sendManyOp);
|
||||
}
|
||||
|
||||
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
TypeRange outputTypes) {
|
||||
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) {
|
||||
SmallVector<Type> tensorTypes;
|
||||
tensorTypes.reserve(outputTypes.size());
|
||||
for (Type outputType : outputTypes)
|
||||
@@ -177,7 +175,8 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
||||
sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
SmallVector<Value> outputBuffers = createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
||||
SmallVector<Value> outputBuffers =
|
||||
createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
||||
|
||||
auto receiveMany = PimReceiveManyOp::create(rewriter,
|
||||
receiveManyOp.getLoc(),
|
||||
@@ -199,10 +198,8 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa
|
||||
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
|
||||
for (Value input : sendManyBatchOp.getInputs())
|
||||
mappedInputs.push_back(mapper.lookup(input));
|
||||
pim::PimSendManyBatchOp::create(rewriter,
|
||||
sendManyBatchOp.getLoc(),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(mappedInputs));
|
||||
pim::PimSendManyBatchOp::create(
|
||||
rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs));
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
||||
@@ -272,6 +269,276 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a tensor type that stacks `count` tensors of `elementType` along
// dimension 0; every other dimension and the element type are preserved.
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  auto packedShape = llvm::to_vector(elementType.getShape());
  packedShape.front() *= count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}
|
||||
|
||||
// Checks whether `values` are consecutive results of a single operation, i.e.
// values[i] is result `startIndex + i` of one defining op. On success, sets
// `owner` to that op and `startIndex` to the first result number; on failure
// the out-parameters are left in an unspecified state.
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
  if (values.empty())
    return false;

  auto firstResult = dyn_cast<OpResult>(values.front());
  if (!firstResult)
    return false;

  owner = firstResult.getOwner();
  startIndex = firstResult.getResultNumber();

  // Every value must come from the same op with strictly consecutive result
  // numbers (the first iteration re-checks values.front() trivially).
  unsigned expectedResultNumber = startIndex;
  for (Value value : values) {
    auto result = dyn_cast<OpResult>(value);
    if (!result || result.getOwner() != owner || result.getResultNumber() != expectedResultNumber)
      return false;
    ++expectedResultNumber;
  }

  return true;
}
|
||||
|
||||
// Builds one tensor.extract_slice that covers `count` consecutive results of
// a pim.extract_rows op, starting at result `startIndex`. Because consecutive
// results correspond to consecutive row ranges of the op's input, a single
// slice along dim 0 of the input replaces all of them. Returns a null Value
// when the shapes are not static/uniform enough to pack safely.
static Value createPackedExtractRowsSlice(
    pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
  if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
    return {};

  // The offset/size arithmetic below assumes every covered result has the
  // same static row shape; bail out when the results are heterogeneous.
  for (unsigned index = 1; index < count; ++index)
    if (extractRowsOp.getOutputs()[startIndex + index].getType() != rowType)
      return {};

  // hasStaticShape() above guarantees dim 0 is static, so no dynamic check
  // is needed here.
  int64_t rowsPerValue = rowType.getDimSize(0);

  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(inputType.getRank());
  sizes.reserve(inputType.getRank());
  strides.reserve(inputType.getRank());

  // Dim 0 selects the packed row range; all trailing dims are taken whole
  // with unit stride.
  offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
  sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
  strides.push_back(rewriter.getIndexAttr(1));
  for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
    offsets.push_back(rewriter.getIndexAttr(0));
    sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
    strides.push_back(rewriter.getIndexAttr(1));
  }

  return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
      .getResult();
}
|
||||
|
||||
// Tries to materialize `values` as a single packed tensor. This succeeds only
// when the values are consecutive results of one pim.extract_rows op, in
// which case a single slice of that op's input covers all of them. Returns a
// null Value otherwise.
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
  Operation* definingOp = nullptr;
  unsigned firstResultIndex = 0;
  if (!getContiguousOpResults(values, definingOp, firstResultIndex))
    return {};

  auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(definingOp);
  if (!extractRowsOp)
    return {};

  unsigned valueCount = static_cast<unsigned>(values.size());
  return createPackedExtractRowsSlice(extractRowsOp, firstResultIndex, valueCount, rewriter, loc);
}
|
||||
|
||||
// Replaces `count` consecutive results of a pim.receive_many op (starting at
// result `startIndex`) with a single packed pim.receive_tensor whose leading
// dimension stacks all covered per-source rows. Returns a null Value when the
// result shapes are not static/uniform enough to pack safely.
static Value createPackedReceiveTensor(
    pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
  if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0)
    return {};

  // The packed buffer size assumes every covered result has the same static
  // row shape; bail out when the results are heterogeneous.
  for (unsigned index = 1; index < count; ++index)
    if (receiveManyOp.getOutputs()[startIndex + index].getType() != rowType)
      return {};

  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  auto outputBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType());

  // Forward only the source core ids belonging to the packed result range;
  // ArrayRef::slice avoids the element-by-element copy.
  ArrayRef<int32_t> packedSourceCoreIds = receiveManyOp.getSourceCoreIds().slice(startIndex, count);

  return pim::PimReceiveTensorOp::create(
             rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(packedSourceCoreIds))
      .getOutput();
}
|
||||
|
||||
// Replaces `count` consecutive results of a pim.map op (starting at result
// `startIndex`) with one packed tensor: the packed inputs are processed by an
// scf.for loop that, per iteration, extracts one input slice, inlines a clone
// of the map body on it, and inserts the produced slice into an accumulator
// tensor carried as the loop's iter_arg. Returns a null Value when the inputs
// cannot be packed or the shapes are not static enough.
//
// NOTE(review): only the types at `startIndex` are inspected; this presumably
// relies on all covered inputs/outputs sharing those types — confirm against
// the pim.map op's verifier.
static Value
createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  // Packing the inputs only succeeds for contiguous pim.extract_rows results
  // (see createPackedTensorForValues); anything else aborts the transform.
  Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
  if (!packedInput)
    return {};

  auto inputType = dyn_cast<RankedTensorType>(mapOp.getInputs()[startIndex].getType());
  auto outputType = dyn_cast<RankedTensorType>(mapOp.getOutputs()[startIndex].getType());
  if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape()
      || inputType.getRank() == 0 || outputType.getRank() == 0)
    return {};

  // Destination tensor for all packed outputs; filled one slice per loop
  // iteration below.
  auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(count));
  auto packedInit =
      tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType());
  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto upper = arith::ConstantIndexOp::create(rewriter, loc, count);
  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
  // The accumulator tensor is threaded through the loop as an iter_arg so the
  // insert_slice chain stays in destination-passing style.
  auto loop = scf::ForOp::create(rewriter, loc, zero, upper, step, ValueRange {packedInit.getResult()});

  {
    // Scope the insertion-point change to the loop body; the guard restores
    // the caller's insertion point afterwards.
    OpBuilder::InsertionGuard guard(rewriter);
    Block* loopBlock = loop.getBody();
    rewriter.setInsertionPointToStart(loopBlock);
    Value iv = loopBlock->getArgument(0);   // induction variable
    Value acc = loopBlock->getArgument(1);  // accumulator iter_arg

    // Row offset of this iteration's input slice: iv * rowsPerValue, with the
    // multiply elided when each value is a single row.
    int64_t inputRowsPerValue = inputType.getDimSize(0);
    Value inputRowOffset = iv;
    if (inputRowsPerValue != 1) {
      auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, inputRowsPerValue);
      inputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
    }

    // Extract one original-sized input slice from the packed input; dim 0 is
    // windowed, trailing dims are taken whole with unit stride.
    SmallVector<OpFoldResult> extractOffsets;
    SmallVector<OpFoldResult> extractSizes;
    SmallVector<OpFoldResult> extractStrides;
    extractOffsets.push_back(inputRowOffset);
    extractSizes.push_back(rewriter.getIndexAttr(inputRowsPerValue));
    extractStrides.push_back(rewriter.getIndexAttr(1));
    for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
      extractOffsets.push_back(rewriter.getIndexAttr(0));
      extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
      extractStrides.push_back(rewriter.getIndexAttr(1));
    }
    auto inputSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides);

    // Inline a clone of the map body with its block argument remapped to this
    // iteration's input slice. Cloned results are remapped so later body ops
    // see the clones, and the insertion point tracks the last clone to keep
    // the original statement order.
    IRMapping mapping;
    Block& body = mapOp.getBody().front();
    mapping.map(body.getArgument(0), inputSlice.getResult());
    for (Operation& bodyOp : body.without_terminator()) {
      Operation* cloned = rewriter.clone(bodyOp, mapping);
      for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
        mapping.map(originalResult, clonedResult);
      rewriter.setInsertionPointAfter(cloned);
    }

    // The value yielded by the map body (through the mapping, or unchanged if
    // it was defined outside the body) is this iteration's output slice.
    auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
    Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));

    // Row offset for inserting the output slice; mirrors the input offset
    // computation but with the output row count.
    int64_t outputRowsPerValue = outputType.getDimSize(0);
    Value outputRowOffset = iv;
    if (outputRowsPerValue != 1) {
      auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, outputRowsPerValue);
      outputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
    }

    SmallVector<OpFoldResult> insertOffsets;
    SmallVector<OpFoldResult> insertSizes;
    SmallVector<OpFoldResult> insertStrides;
    insertOffsets.push_back(outputRowOffset);
    insertSizes.push_back(rewriter.getIndexAttr(outputRowsPerValue));
    insertStrides.push_back(rewriter.getIndexAttr(1));
    for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
      insertOffsets.push_back(rewriter.getIndexAttr(0));
      insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
      insertStrides.push_back(rewriter.getIndexAttr(1));
    }

    // Insert the computed slice into the accumulator and yield the updated
    // tensor as next iteration's iter_arg.
    auto inserted =
        tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
    scf::YieldOp::create(rewriter, loc, inserted.getResult());
  }

  // The fully assembled packed output is the loop's carried result.
  return loop.getResult(0);
}
|
||||
|
||||
// Post-lowering compaction: fuses groups of per-value pim ops into packed
// tensor ops. Two rewrites run over `funcOp`:
//   1. pim.send_many whose inputs form one packable tensor becomes a single
//      pim.send_tensor.
//   2. pim.concat along axis 0 has maximal runs of consecutive results from
//      one producer (map / receive_many / extract_rows) replaced by a single
//      packed value.
// Finally, producers left without uses by the rewrites are erased.
static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first, then rewrite: mutating the IR inside walk() would
  // invalidate the traversal.
  SmallVector<pim::PimSendManyOp> sendManyOps;
  funcOp.walk([&](pim::PimSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
  for (auto sendManyOp : sendManyOps) {
    if (sendManyOp.getInputs().empty())
      continue;

    rewriter.setInsertionPoint(sendManyOp);
    // Only succeeds when ALL inputs pack into one tensor; partial packing of
    // a send is not attempted.
    Value packedInput = createPackedTensorForValues(sendManyOp.getInputs(), rewriter, sendManyOp.getLoc());
    if (!packedInput)
      continue;

    pim::PimSendTensorOp::create(rewriter, sendManyOp.getLoc(), packedInput, sendManyOp.getTargetCoreIdsAttr());
    rewriter.eraseOp(sendManyOp);
  }

  SmallVector<pim::PimConcatOp> concatOps;
  funcOp.walk([&](pim::PimConcatOp concatOp) { concatOps.push_back(concatOp); });
  for (auto concatOp : concatOps) {
    // Packing stacks along dim 0, so only axis-0 concats are eligible.
    if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
      continue;

    SmallVector<Value> packedInputs;
    bool changed = false;
    rewriter.setInsertionPoint(concatOp);

    // Scan the operand list, greedily extending each run of consecutive
    // results from the same producer; `index` advances by whole runs.
    for (unsigned index = 0; index < concatOp.getInputs().size();) {
      Value input = concatOp.getInputs()[index];
      auto result = dyn_cast<OpResult>(input);
      if (!result) {
        // Block arguments etc. cannot be packed; keep them as-is.
        packedInputs.push_back(input);
        ++index;
        continue;
      }

      Operation* owner = result.getOwner();
      unsigned startIndex = result.getResultNumber();
      unsigned endIndex = index + 1;
      while (endIndex < concatOp.getInputs().size()) {
        auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
        if (!nextResult || nextResult.getOwner() != owner
            || nextResult.getResultNumber() != startIndex + (endIndex - index))
          break;
        ++endIndex;
      }

      // Dispatch on the producer kind; each helper returns a null Value when
      // the run cannot be packed (e.g. non-static shapes), in which case the
      // original operands are kept verbatim.
      unsigned count = endIndex - index;
      Value packedInput;
      if (auto mapOp = dyn_cast<pim::PimMapOp>(owner))
        packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc());
      else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(owner))
        packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc());
      else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());

      if (packedInput) {
        packedInputs.push_back(packedInput);
        changed = true;
      }
      else {
        for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
          packedInputs.push_back(concatOp.getInputs()[oldIndex]);
      }

      index = endIndex;
    }

    if (!changed)
      continue;

    // Axis-0 concat of the packed runs equals the original concat of the
    // individual values, so output type/buffer carry over unchanged.
    auto newConcat = pim::PimConcatOp::create(rewriter,
                                              concatOp.getLoc(),
                                              concatOp.getOutput().getType(),
                                              concatOp.getAxisAttr(),
                                              ValueRange(packedInputs),
                                              concatOp.getOutputBuffer());
    rewriter.replaceOp(concatOp, newConcat.getOutput());
  }

  // Drop producers orphaned by the rewrites above. Reverse order erases
  // later (dependent) ops first.
  // NOTE(review): a pim.receive_many whose results were only PARTIALLY packed
  // keeps its remaining uses and survives alongside the new
  // pim.receive_tensor — confirm that the duplicated receive is intended (or
  // that the grouping always covers all results of a receive).
  auto eraseUnusedOps = [&](auto tag) {
    using OpTy = decltype(tag);
    SmallVector<OpTy> ops;
    funcOp.walk([&](OpTy op) { ops.push_back(op); });
    for (auto op : llvm::reverse(ops))
      if (op->use_empty())
        rewriter.eraseOp(op);
  };
  eraseUnusedOps(pim::PimMapOp {});
  eraseUnusedOps(pim::PimReceiveManyOp {});
  eraseUnusedOps(pim::PimExtractRowsOp {});
  eraseUnusedOps(pim::PimEmptyManyOp {});
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
@@ -399,21 +666,21 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
||||
}
|
||||
|
||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatResult = [](Operation *op) -> Value {
|
||||
auto getConcatResult = [](Operation* op) -> Value {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getResult();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getOutput();
|
||||
return {};
|
||||
};
|
||||
auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
|
||||
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getDim();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getAxis();
|
||||
return std::nullopt;
|
||||
};
|
||||
auto getConcatOperands = [](Operation *op) -> OperandRange {
|
||||
auto getConcatOperands = [](Operation* op) -> OperandRange {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getOperands();
|
||||
return cast<pim::PimConcatOp>(op).getInputs();
|
||||
@@ -799,6 +1066,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto extractRowsOp : remainingExtractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
compactPimTensorGroups(funcOp, rewriter);
|
||||
|
||||
// Dump to file for debug
|
||||
bool hasSpatialOps = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
|
||||
Reference in New Issue
Block a user