fast pim bufferization using tensors
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m29s
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m29s
This commit is contained in:
@@ -178,7 +178,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void PimMemory::remove(mlir::Value val) {
|
||||
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
||||
globalMemEntriesMap.erase(removeIter);
|
||||
@@ -370,11 +369,21 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
|
||||
for (auto [outputBuffer, sourceCoreId] : llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
|
||||
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
for (auto [outputBuffer, sourceCoreId] :
|
||||
llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
|
||||
emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
|
||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
}
|
||||
@@ -384,7 +393,15 @@ void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticVa
|
||||
emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
|
||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
|
||||
assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
|
||||
|
||||
@@ -393,13 +410,8 @@ void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const
|
||||
size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
|
||||
|
||||
for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
|
||||
emitMemCopyOp("lmv",
|
||||
addressOf(outputBuffer, knowledge),
|
||||
0,
|
||||
inputAddr,
|
||||
rowIndex * rowSizeInBytes,
|
||||
rowSizeInBytes,
|
||||
"len");
|
||||
emitMemCopyOp(
|
||||
"lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len");
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -742,10 +754,8 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
||||
for (mlir::Value input : sendManyBatchOp.getInputs())
|
||||
mappedInputs.push_back(mapper.lookup(input));
|
||||
|
||||
pim::PimSendManyOp::create(builder,
|
||||
sendManyBatchOp.getLoc(),
|
||||
builder.getDenseI32ArrayAttr(laneTargetCoreIds),
|
||||
ValueRange(mappedInputs));
|
||||
pim::PimSendManyOp::create(
|
||||
builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -773,13 +783,13 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
||||
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
|
||||
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
|
||||
|
||||
auto scalarReceiveMany =
|
||||
pim::PimReceiveManyOp::create(builder,
|
||||
auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder,
|
||||
receiveManyBatchOp.getLoc(),
|
||||
receiveManyBatchOp->getResultTypes(),
|
||||
ValueRange(mappedOutputBuffers),
|
||||
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
|
||||
for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
|
||||
for (auto [originalOutput, scalarOutput] :
|
||||
llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
|
||||
mapper.map(originalOutput, scalarOutput);
|
||||
continue;
|
||||
}
|
||||
@@ -904,10 +914,14 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
||||
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
|
||||
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
|
||||
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
|
||||
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
|
||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
||||
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
|
||||
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
|
||||
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
|
||||
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
|
||||
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
|
||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
@@ -117,8 +118,10 @@ public:
|
||||
|
||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
|
||||
@@ -159,9 +159,7 @@ static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRe
|
||||
rewriter.eraseOp(sendManyOp);
|
||||
}
|
||||
|
||||
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
TypeRange outputTypes) {
|
||||
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) {
|
||||
SmallVector<Type> tensorTypes;
|
||||
tensorTypes.reserve(outputTypes.size());
|
||||
for (Type outputType : outputTypes)
|
||||
@@ -177,7 +175,8 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
||||
sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
SmallVector<Value> outputBuffers = createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
||||
SmallVector<Value> outputBuffers =
|
||||
createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
||||
|
||||
auto receiveMany = PimReceiveManyOp::create(rewriter,
|
||||
receiveManyOp.getLoc(),
|
||||
@@ -199,10 +198,8 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa
|
||||
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
|
||||
for (Value input : sendManyBatchOp.getInputs())
|
||||
mappedInputs.push_back(mapper.lookup(input));
|
||||
pim::PimSendManyBatchOp::create(rewriter,
|
||||
sendManyBatchOp.getLoc(),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(mappedInputs));
|
||||
pim::PimSendManyBatchOp::create(
|
||||
rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs));
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
||||
@@ -272,6 +269,276 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
}
|
||||
}
|
||||
|
||||
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
packedShape[0] *= count;
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
|
||||
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
||||
if (values.empty())
|
||||
return false;
|
||||
|
||||
auto firstResult = dyn_cast<OpResult>(values.front());
|
||||
if (!firstResult)
|
||||
return false;
|
||||
|
||||
owner = firstResult.getOwner();
|
||||
startIndex = firstResult.getResultNumber();
|
||||
for (auto [index, value] : llvm::enumerate(values)) {
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
Operation* owner = nullptr;
|
||||
unsigned startIndex = 0;
|
||||
if (!getContiguousOpResults(values, owner, startIndex))
|
||||
return {};
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
|
||||
return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
static Value createPackedReceiveTensor(
|
||||
pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
|
||||
if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
auto outputBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType());
|
||||
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(count);
|
||||
ArrayRef<int32_t> allSourceCoreIds = receiveManyOp.getSourceCoreIds();
|
||||
for (unsigned index = 0; index < count; ++index)
|
||||
sourceCoreIds.push_back(allSourceCoreIds[startIndex + index]);
|
||||
|
||||
return pim::PimReceiveTensorOp::create(
|
||||
rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value
|
||||
createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
|
||||
if (!packedInput)
|
||||
return {};
|
||||
|
||||
auto inputType = dyn_cast<RankedTensorType>(mapOp.getInputs()[startIndex].getType());
|
||||
auto outputType = dyn_cast<RankedTensorType>(mapOp.getOutputs()[startIndex].getType());
|
||||
if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape()
|
||||
|| inputType.getRank() == 0 || outputType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(count));
|
||||
auto packedInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, loc, count);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Block* loopBlock = loop.getBody();
|
||||
rewriter.setInsertionPointToStart(loopBlock);
|
||||
Value iv = loopBlock->getArgument(0);
|
||||
Value acc = loopBlock->getArgument(1);
|
||||
|
||||
int64_t inputRowsPerValue = inputType.getDimSize(0);
|
||||
Value inputRowOffset = iv;
|
||||
if (inputRowsPerValue != 1) {
|
||||
auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, inputRowsPerValue);
|
||||
inputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets;
|
||||
SmallVector<OpFoldResult> extractSizes;
|
||||
SmallVector<OpFoldResult> extractStrides;
|
||||
extractOffsets.push_back(inputRowOffset);
|
||||
extractSizes.push_back(rewriter.getIndexAttr(inputRowsPerValue));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
extractOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
auto inputSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides);
|
||||
|
||||
IRMapping mapping;
|
||||
Block& body = mapOp.getBody().front();
|
||||
mapping.map(body.getArgument(0), inputSlice.getResult());
|
||||
for (Operation& bodyOp : body.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(bodyOp, mapping);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
|
||||
mapping.map(originalResult, clonedResult);
|
||||
rewriter.setInsertionPointAfter(cloned);
|
||||
}
|
||||
|
||||
auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
|
||||
Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||
|
||||
int64_t outputRowsPerValue = outputType.getDimSize(0);
|
||||
Value outputRowOffset = iv;
|
||||
if (outputRowsPerValue != 1) {
|
||||
auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, outputRowsPerValue);
|
||||
outputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets;
|
||||
SmallVector<OpFoldResult> insertSizes;
|
||||
SmallVector<OpFoldResult> insertStrides;
|
||||
insertOffsets.push_back(outputRowOffset);
|
||||
insertSizes.push_back(rewriter.getIndexAttr(outputRowsPerValue));
|
||||
insertStrides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
|
||||
insertOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
|
||||
insertStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
auto inserted =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, loc, inserted.getResult());
|
||||
}
|
||||
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<pim::PimSendManyOp> sendManyOps;
|
||||
funcOp.walk([&](pim::PimSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
|
||||
for (auto sendManyOp : sendManyOps) {
|
||||
if (sendManyOp.getInputs().empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(sendManyOp);
|
||||
Value packedInput = createPackedTensorForValues(sendManyOp.getInputs(), rewriter, sendManyOp.getLoc());
|
||||
if (!packedInput)
|
||||
continue;
|
||||
|
||||
pim::PimSendTensorOp::create(rewriter, sendManyOp.getLoc(), packedInput, sendManyOp.getTargetCoreIdsAttr());
|
||||
rewriter.eraseOp(sendManyOp);
|
||||
}
|
||||
|
||||
SmallVector<pim::PimConcatOp> concatOps;
|
||||
funcOp.walk([&](pim::PimConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
for (auto concatOp : concatOps) {
|
||||
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
|
||||
continue;
|
||||
|
||||
SmallVector<Value> packedInputs;
|
||||
bool changed = false;
|
||||
rewriter.setInsertionPoint(concatOp);
|
||||
|
||||
for (unsigned index = 0; index < concatOp.getInputs().size();) {
|
||||
Value input = concatOp.getInputs()[index];
|
||||
auto result = dyn_cast<OpResult>(input);
|
||||
if (!result) {
|
||||
packedInputs.push_back(input);
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* owner = result.getOwner();
|
||||
unsigned startIndex = result.getResultNumber();
|
||||
unsigned endIndex = index + 1;
|
||||
while (endIndex < concatOp.getInputs().size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
|
||||
if (!nextResult || nextResult.getOwner() != owner
|
||||
|| nextResult.getResultNumber() != startIndex + (endIndex - index))
|
||||
break;
|
||||
++endIndex;
|
||||
}
|
||||
|
||||
unsigned count = endIndex - index;
|
||||
Value packedInput;
|
||||
if (auto mapOp = dyn_cast<pim::PimMapOp>(owner))
|
||||
packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(owner))
|
||||
packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
|
||||
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
|
||||
if (packedInput) {
|
||||
packedInputs.push_back(packedInput);
|
||||
changed = true;
|
||||
}
|
||||
else {
|
||||
for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
|
||||
packedInputs.push_back(concatOp.getInputs()[oldIndex]);
|
||||
}
|
||||
|
||||
index = endIndex;
|
||||
}
|
||||
|
||||
if (!changed)
|
||||
continue;
|
||||
|
||||
auto newConcat = pim::PimConcatOp::create(rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
concatOp.getOutputBuffer());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
}
|
||||
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
};
|
||||
eraseUnusedOps(pim::PimMapOp {});
|
||||
eraseUnusedOps(pim::PimReceiveManyOp {});
|
||||
eraseUnusedOps(pim::PimExtractRowsOp {});
|
||||
eraseUnusedOps(pim::PimEmptyManyOp {});
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
@@ -399,21 +666,21 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
||||
}
|
||||
|
||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatResult = [](Operation *op) -> Value {
|
||||
auto getConcatResult = [](Operation* op) -> Value {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getResult();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getOutput();
|
||||
return {};
|
||||
};
|
||||
auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
|
||||
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getDim();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getAxis();
|
||||
return std::nullopt;
|
||||
};
|
||||
auto getConcatOperands = [](Operation *op) -> OperandRange {
|
||||
auto getConcatOperands = [](Operation* op) -> OperandRange {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getOperands();
|
||||
return cast<pim::PimConcatOp>(op).getInputs();
|
||||
@@ -799,6 +1066,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto extractRowsOp : remainingExtractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
compactPimTensorGroups(funcOp, rewriter);
|
||||
|
||||
// Dump to file for debug
|
||||
bool hasSpatialOps = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
|
||||
@@ -133,6 +133,18 @@ def PimSendManyOp : PimOp<"send_many", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendTensorOp : PimOp<"send_tensor", []> {
|
||||
let summary = "Send equal contiguous chunks of one tensor to target cores";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendBatchOp : PimOp<"send_batch", []> {
|
||||
let summary = "Send a per-lane tensor to target cores from a batched core";
|
||||
|
||||
@@ -203,6 +215,28 @@ def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive equal contiguous chunks from source cores into one tensor";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive per-lane tensors from source cores into a batched core";
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -100,9 +100,9 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute("operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr(
|
||||
{static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
@@ -267,6 +267,33 @@ ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimSendTensorOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
@@ -333,6 +360,43 @@ ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result)
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutputBuffer().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
Type outputBufferType;
|
||||
Type outputType;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
||||
|| parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
|
||||
@@ -48,12 +48,32 @@ static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
|
||||
return op->emitError() << kind << " values must all use the same shaped container kind";
|
||||
if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
|
||||
if (firstShapedType.getElementType() != shapedType.getElementType()
|
||||
|| firstShapedType.getShape() != shapedType.getShape())
|
||||
return op->emitError() << kind << " values must all have the same shape and element type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
@@ -61,9 +81,7 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return coreBatchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
|
||||
ArrayRef<int32_t> coreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside pim.core_batch");
|
||||
@@ -109,7 +127,8 @@ LogicalResult PimMapOp::verify() {
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != inputType)
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
|
||||
@@ -117,7 +136,8 @@ LogicalResult PimMapOp::verify() {
|
||||
return emitError("body must terminate with pim.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputType)
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
@@ -129,6 +149,10 @@ LogicalResult PimSendManyOp::verify() {
|
||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
@@ -153,6 +177,14 @@ LogicalResult PimReceiveManyOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveManyBatchOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
|
||||
@@ -34,10 +34,8 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static FailureOr<Value> getBufferOrValue(RewriterBase& rewriter,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) {
|
||||
static FailureOr<Value>
|
||||
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
|
||||
if (isa<BufferLikeType>(value.getType()))
|
||||
return value;
|
||||
return getBuffer(rewriter, value, options, state);
|
||||
@@ -205,13 +203,37 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveM
|
||||
resultTypes.push_back(outputBufferOpt->getType());
|
||||
}
|
||||
|
||||
auto newOp = PimReceiveManyOp::create(
|
||||
rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr());
|
||||
auto newOp = PimReceiveManyOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
TypeRange(resultTypes),
|
||||
ValueRange(outputBuffers),
|
||||
receiveOp.getSourceCoreIdsAttr());
|
||||
rewriter.replaceOp(receiveOp, newOp.getOutputs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveTensorOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveTensorOp>(op);
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveManyBatchOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveManyBatchOpInterface, PimReceiveManyBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
@@ -337,6 +359,30 @@ struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOp
|
||||
}
|
||||
};
|
||||
|
||||
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto sendOp = cast<PimSendTensorOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
|
||||
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
@@ -349,23 +395,26 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
|
||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||
auto mapOp = cast<PimMapOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
|
||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|
||||
|| mapOp.getInputs().empty())
|
||||
return {};
|
||||
|
||||
return {{&mapOp->getOpOperand(0), BufferRelation::Equivalent}};
|
||||
return {
|
||||
{&mapOp->getOpOperand(0), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||
|
||||
FailureOr<BufferLikeType>
|
||||
getBufferType(Operation* op,
|
||||
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
SmallVector<Value>& invocationStack) const {
|
||||
auto mapOp = cast<PimMapOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
|
||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|
||||
|| mapOp.getInputs().empty())
|
||||
return failure();
|
||||
|
||||
auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
|
||||
@@ -417,13 +466,9 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
|
||||
};
|
||||
|
||||
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return true;
|
||||
}
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return false;
|
||||
}
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
@@ -436,15 +481,14 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
return {};
|
||||
|
||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
||||
return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
|
||||
return {
|
||||
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
|
||||
return false;
|
||||
}
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||
|
||||
FailureOr<BufferLikeType>
|
||||
getBufferType(Operation* op,
|
||||
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
@@ -467,13 +511,11 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
BufferizationState& state) const {
|
||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
||||
|
||||
bool alreadyBufferized = llvm::all_of(coreBatchOp.getWeights(), [](Value weight) {
|
||||
return isa<BufferLikeType>(weight.getType());
|
||||
}) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) {
|
||||
return isa<BufferLikeType>(input.getType());
|
||||
}) && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
|
||||
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -693,8 +735,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
|
||||
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
|
||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
||||
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
|
||||
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
|
||||
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
|
||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
|
||||
Reference in New Issue
Block a user