compact pim IR
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m15s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m15s
This commit is contained in:
@@ -150,116 +150,105 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri
|
||||
|
||||
/// Lowers a `spatial::SpatChannelSendManyOp` to a single `PimSendManyOp`.
///
/// Each target core id of the source op is passed through
/// `translateSpatialCoreIdToPimCoreId`; the translated ids are attached to the
/// new op as a dense i32 array attribute and the source op's inputs are
/// forwarded unchanged. The source op is erased afterwards.
///
/// \param sendManyOp  op to lower; erased before returning.
/// \param rewriter    rewriter used to create the new op and erase the old one.
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendManyOp);

  // Translate every spatial core id up front so that exactly one op is emitted
  // for the whole group of inputs.
  SmallVector<int32_t> targetCoreIds;
  targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
  for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
    targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));

  PimSendManyOp::create(
      rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs());
  rewriter.eraseOp(sendManyOp);
}
|
||||
|
||||
/// Creates one `pim::PimEmptyManyOp` whose result types are `outputTypes` and
/// returns its outputs as a vector of values.
///
/// \param rewriter     rewriter used to create the op.
/// \param loc          location attached to the new op.
/// \param outputTypes  result types requested from the new op.
/// \return the outputs of the created op, in order.
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
                                                     Location loc,
                                                     TypeRange outputTypes) {
  // `outputTypes` is already a TypeRange, so it is handed to the builder
  // directly instead of being copied into a temporary vector first.
  auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, outputTypes);
  auto outputs = emptyMany.getOutputs();
  return SmallVector<Value>(outputs.begin(), outputs.end());
}
|
||||
|
||||
/// Lowers a `spatial::SpatChannelReceiveManyOp` to a single `PimReceiveManyOp`.
///
/// Each source core id is passed through `translateSpatialCoreIdToPimCoreId`,
/// output buffers matching the op's result types are obtained from
/// `createManyEmptyTensorsLike`, and the source op is replaced exactly once by
/// the outputs of the new op.
///
/// \param receiveManyOp  op to lower; replaced before returning.
/// \param rewriter       rewriter used to create the new ops and do the replacement.
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(receiveManyOp);

  SmallVector<int32_t> sourceCoreIds;
  sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
  for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
    sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));

  SmallVector<Value> outputBuffers =
      createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());

  auto receiveMany = PimReceiveManyOp::create(rewriter,
                                              receiveManyOp.getLoc(),
                                              receiveManyOp.getResultTypes(),
                                              ValueRange(outputBuffers),
                                              rewriter.getDenseI32ArrayAttr(sourceCoreIds));
  // Single replacement: the op must not be touched again after this call.
  rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs());
}
|
||||
|
||||
/// Lowers a `spatial::SpatChannelSendManyBatchOp` to a single
/// `pim::PimSendManyBatchOp`.
///
/// All target core ids are passed through `translateSpatialCoreIdToPimCoreId`
/// and attached as one dense i32 array attribute; every input is looked up in
/// `mapper` before being forwarded to the new op.
///
/// This function neither sets the rewriter's insertion point nor erases the
/// source op. NOTE(review): presumably the caller handles both - confirm.
///
/// \param sendManyBatchOp  op to lower.
/// \param laneCount        kept for interface compatibility; not read here
///                         because the whole id list is forwarded as-is.
/// \param mapper           maps the op's inputs to their already-lowered values.
/// \param rewriter         rewriter used to create the new op.
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
                                      int32_t laneCount,
                                      IRMapping& mapper,
                                      IRRewriter& rewriter) {
  SmallVector<int32_t> targetCoreIds;
  targetCoreIds.reserve(sendManyBatchOp.getTargetCoreIds().size());
  for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds())
    targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));

  SmallVector<Value> mappedInputs;
  mappedInputs.reserve(sendManyBatchOp.getInputs().size());
  for (Value input : sendManyBatchOp.getInputs())
    mappedInputs.push_back(mapper.lookup(input));

  pim::PimSendManyBatchOp::create(rewriter,
                                  sendManyBatchOp.getLoc(),
                                  rewriter.getDenseI32ArrayAttr(targetCoreIds),
                                  ValueRange(mappedInputs));
}
|
||||
|
||||
/// Lowers a `spatial::SpatChannelReceiveManyBatchOp` to a single
/// `pim::PimReceiveManyBatchOp`.
///
/// All source core ids are passed through `translateSpatialCoreIdToPimCoreId`,
/// output buffers matching the op's result types are obtained from
/// `createManyEmptyTensorsLike`, and each original output is then mapped in
/// `mapper` to the corresponding output of the new op.
///
/// The source op is not replaced or erased here; results are published only
/// through `mapper`.
///
/// \param receiveManyBatchOp  op to lower.
/// \param laneCount           kept for interface compatibility; not read here.
/// \param mapper              receives the original-output -> new-output mapping.
/// \param rewriter            rewriter used to create the new ops.
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
                                         int32_t laneCount,
                                         IRMapping& mapper,
                                         IRRewriter& rewriter) {
  SmallVector<int32_t> sourceCoreIds;
  sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size());
  for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds())
    sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));

  SmallVector<Value> outputBuffers =
      createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes());

  // Emitted once for the whole op, not once per output.
  auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter,
                                                        receiveManyBatchOp.getLoc(),
                                                        receiveManyBatchOp.getResultTypes(),
                                                        ValueRange(outputBuffers),
                                                        rewriter.getDenseI32ArrayAttr(sourceCoreIds));
  for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs()))
    mapper.map(output, received);
}
|
||||
|
||||
/// Lowers a `spatial::SpatExtractRowsOp` to a single `pim::PimExtractRowsOp`.
///
/// Output buffers matching the op's result types are obtained from
/// `createManyEmptyTensorsLike`; the op's input operand is forwarded unchanged
/// to the new op, and the source op is replaced exactly once by the new op's
/// outputs.
///
/// \param extractRowsOp  op to lower; replaced before returning.
/// \param rewriter       rewriter used to create the new ops and do the replacement.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);

  SmallVector<Value> outputBuffers =
      createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes());

  auto extractRows = pim::PimExtractRowsOp::create(rewriter,
                                                   extractRowsOp.getLoc(),
                                                   extractRowsOp.getResultTypes(),
                                                   extractRowsOp.getInput(),
                                                   ValueRange(outputBuffers));
  // Single replacement: the op must not be touched again after this call.
  rewriter.replaceOp(extractRowsOp, extractRows.getOutputs());
}
|
||||
|
||||
/// Lowers a `spatial::SpatConcatOp` to a `pim::PimConcatOp`.
///
/// An output buffer is created with `createEmptyTensorFromShaped` from the
/// op's output type; the new op receives the original axis (as an i64
/// attribute), the original inputs and that buffer, and its output replaces
/// the source op.
///
/// \param concatOp  op to lower; replaced before returning. Its output type
///                  must be a `ShapedType` (enforced by `cast`).
/// \param rewriter  rewriter used to create the new ops and do the replacement.
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(concatOp);

  auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
  Value outputBuffer = createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), outputType).getResult();

  Value concatenated = pim::PimConcatOp::create(rewriter,
                                                concatOp.getLoc(),
                                                concatOp.getOutput().getType(),
                                                rewriter.getI64IntegerAttr(concatOp.getAxis()),
                                                concatOp.getInputs(),
                                                outputBuffer)
                           .getOutput();
  rewriter.replaceOp(concatOp, concatenated);
}
|
||||
|
||||
@@ -282,34 +271,23 @@ static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewrit
|
||||
}
|
||||
}
|
||||
|
||||
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatMapOp> mapOps;
|
||||
funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
|
||||
funcOp.walk([&](spatial::SpatMapOp mapOp) {
|
||||
if (mapOp->getParentOfType<pim::PimCoreOp>() || mapOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
mapOps.push_back(mapOp);
|
||||
});
|
||||
|
||||
for (auto mapOp : mapOps) {
|
||||
Block& body = mapOp.getBody().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
|
||||
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(mapOp.getInputs().size());
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
for (Value input : mapOp.getInputs()) {
|
||||
IRMapping mapping;
|
||||
mapping.map(body.getArgument(0), input);
|
||||
auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs());
|
||||
rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin());
|
||||
|
||||
Value replacement = input;
|
||||
for (Operation& op : body.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapping.map(originalResult, clonedResult);
|
||||
rewriter.setInsertionPointAfter(cloned);
|
||||
}
|
||||
|
||||
replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||
replacements.push_back(replacement);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(mapOp, replacements);
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.replaceOpWithNewOp<pim::PimYieldOp>(yieldOp, yieldOp.getOutputs());
|
||||
rewriter.replaceOp(mapOp, pimMap.getOutputs());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -440,8 +418,28 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
||||
}
|
||||
|
||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatResult = [](Operation *op) -> Value {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getResult();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getOutput();
|
||||
return {};
|
||||
};
|
||||
auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getDim();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getAxis();
|
||||
return std::nullopt;
|
||||
};
|
||||
auto getConcatOperands = [](Operation *op) -> OperandRange {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getOperands();
|
||||
return cast<pim::PimConcatOp>(op).getInputs();
|
||||
};
|
||||
|
||||
auto uses = value.getUses();
|
||||
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp>(uses.begin()->getOwner()))
|
||||
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
|
||||
return std::nullopt;
|
||||
|
||||
auto valueType = dyn_cast<ShapedType>(value.getType());
|
||||
@@ -453,18 +451,19 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
Value currentValue = value;
|
||||
Operation* currentUser = uses.begin()->getOwner();
|
||||
|
||||
while (auto concatOp = dyn_cast<tensor::ConcatOp>(currentUser)) {
|
||||
while (isa<tensor::ConcatOp, pim::PimConcatOp>(currentUser)) {
|
||||
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
|
||||
int64_t axis = concatOp.getDim();
|
||||
for (Value operand : concatOp.getOperands().take_front(operandIndex))
|
||||
int64_t axis = *getConcatAxis(currentUser);
|
||||
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
|
||||
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
|
||||
|
||||
auto concatType = dyn_cast<ShapedType>(concatOp.getResult().getType());
|
||||
Value concatResult = getConcatResult(currentUser);
|
||||
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
|
||||
if (!concatType || !concatType.hasStaticShape())
|
||||
return std::nullopt;
|
||||
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
|
||||
|
||||
currentValue = concatOp.getResult();
|
||||
currentValue = concatResult;
|
||||
auto currentUses = currentValue.getUses();
|
||||
if (rangeLength(currentUses) != 1)
|
||||
return std::nullopt;
|
||||
@@ -638,7 +637,6 @@ void SpatialToPimPass::runOnOperation() {
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
expandMapOps(funcOp, rewriter);
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect,
|
||||
@@ -687,6 +685,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
runOnComputeBatchOp(computeBatchOp, rewriter);
|
||||
}
|
||||
|
||||
lowerMapOps(funcOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
receiveOps.push_back(op);
|
||||
@@ -1317,7 +1317,8 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||
Operation* onlyUser = *op->getUsers().begin();
|
||||
isExclusivelyOwnedByReturnChain =
|
||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
|
||||
isa<func::ReturnOp, tensor::ConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
||||
|| isChannelUseChainOp(onlyUser);
|
||||
}
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
return;
|
||||
@@ -1341,6 +1342,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user