Clean up unused channel operations and related logic
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -519,61 +519,12 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
|||||||
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
|
||||||
const StaticValueKnowledge& knowledge) const {
|
|
||||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
|
||||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
|
|
||||||
/ receiveTensorOp.getSourceCoreIds().size();
|
|
||||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
|
||||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp,
|
|
||||||
unsigned lane,
|
|
||||||
const StaticValueKnowledge& knowledge) const {
|
|
||||||
emitCommunicationOp(
|
|
||||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreIds()[lane], receiveOp.getSize());
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
|
|
||||||
ArrayRef<int32_t> laneCoreIds,
|
|
||||||
const StaticValueKnowledge& knowledge) const {
|
|
||||||
size_t outputAddr = addressOf(receiveOp.getOutputBuffer(), knowledge);
|
|
||||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveOp.getOutputBuffer().getType()))
|
|
||||||
/ laneCoreIds.size();
|
|
||||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(laneCoreIds))
|
|
||||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto targetCoreId = indexOf(sendOp.getTargetCoreId(), knowledge);
|
auto targetCoreId = indexOf(sendOp.getTargetCoreId(), knowledge);
|
||||||
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
|
||||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
|
||||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
|
|
||||||
/ sendTensorOp.getTargetCoreIds().size();
|
|
||||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
|
||||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendBatchOp(pim::PimSendBatchOp sendOp,
|
|
||||||
unsigned lane,
|
|
||||||
const StaticValueKnowledge& knowledge) const {
|
|
||||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreIds()[lane], sendOp.getSize());
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
|
|
||||||
ArrayRef<int32_t> laneCoreIds,
|
|
||||||
const StaticValueKnowledge& knowledge) const {
|
|
||||||
size_t inputAddr = addressOf(sendOp.getInput(), knowledge);
|
|
||||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendOp.getInput().getType())) / laneCoreIds.size();
|
|
||||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(laneCoreIds))
|
|
||||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
|
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
|
||||||
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
|
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
|
||||||
@@ -902,13 +853,7 @@ enum class CompiledCoreOpKind : uint8_t {
|
|||||||
Store,
|
Store,
|
||||||
Lmv,
|
Lmv,
|
||||||
Receive,
|
Receive,
|
||||||
ReceiveBatch,
|
|
||||||
ReceiveTensor,
|
|
||||||
ReceiveTensorBatch,
|
|
||||||
Send,
|
Send,
|
||||||
SendBatch,
|
|
||||||
SendTensor,
|
|
||||||
SendTensorBatch,
|
|
||||||
Concat,
|
Concat,
|
||||||
Vmm,
|
Vmm,
|
||||||
Transpose,
|
Transpose,
|
||||||
@@ -952,20 +897,8 @@ static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
|
|||||||
return CompiledCoreOpKind::Lmv;
|
return CompiledCoreOpKind::Lmv;
|
||||||
if (isa<pim::PimReceiveOp>(op))
|
if (isa<pim::PimReceiveOp>(op))
|
||||||
return CompiledCoreOpKind::Receive;
|
return CompiledCoreOpKind::Receive;
|
||||||
if (isa<pim::PimReceiveBatchOp>(op))
|
|
||||||
return CompiledCoreOpKind::ReceiveBatch;
|
|
||||||
if (isa<pim::PimReceiveTensorOp>(op))
|
|
||||||
return CompiledCoreOpKind::ReceiveTensor;
|
|
||||||
if (isa<pim::PimReceiveTensorBatchOp>(op))
|
|
||||||
return CompiledCoreOpKind::ReceiveTensorBatch;
|
|
||||||
if (isa<pim::PimSendOp>(op))
|
if (isa<pim::PimSendOp>(op))
|
||||||
return CompiledCoreOpKind::Send;
|
return CompiledCoreOpKind::Send;
|
||||||
if (isa<pim::PimSendBatchOp>(op))
|
|
||||||
return CompiledCoreOpKind::SendBatch;
|
|
||||||
if (isa<pim::PimSendTensorOp>(op))
|
|
||||||
return CompiledCoreOpKind::SendTensor;
|
|
||||||
if (isa<pim::PimSendTensorBatchOp>(op))
|
|
||||||
return CompiledCoreOpKind::SendTensorBatch;
|
|
||||||
if (isa<pim::PimConcatOp>(op))
|
if (isa<pim::PimConcatOp>(op))
|
||||||
return CompiledCoreOpKind::Concat;
|
return CompiledCoreOpKind::Concat;
|
||||||
if (isa<pim::PimVMMOp>(op))
|
if (isa<pim::PimVMMOp>(op))
|
||||||
@@ -1108,43 +1041,9 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
|||||||
case CompiledCoreOpKind::Receive:
|
case CompiledCoreOpKind::Receive:
|
||||||
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
|
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
case CompiledCoreOpKind::ReceiveBatch:
|
|
||||||
if (!batchLane)
|
|
||||||
return failure();
|
|
||||||
coreCodeGen.codeGenReceiveBatchOp(cast<pim::PimReceiveBatchOp>(node.op), *batchLane, knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::ReceiveTensor:
|
|
||||||
coreCodeGen.codeGenReceiveTensorOp(cast<pim::PimReceiveTensorOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::ReceiveTensorBatch:
|
|
||||||
if (!batchLane || !batchLaneCount)
|
|
||||||
return failure();
|
|
||||||
coreCodeGen.codeGenReceiveTensorBatchOp(cast<pim::PimReceiveTensorBatchOp>(node.op),
|
|
||||||
getLaneChunkCoreIds(cast<pim::PimReceiveTensorBatchOp>(node.op).getSourceCoreIds(),
|
|
||||||
*batchLaneCount,
|
|
||||||
*batchLane),
|
|
||||||
knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::Send:
|
case CompiledCoreOpKind::Send:
|
||||||
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
|
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
case CompiledCoreOpKind::SendBatch:
|
|
||||||
if (!batchLane)
|
|
||||||
return failure();
|
|
||||||
coreCodeGen.codeGenSendBatchOp(cast<pim::PimSendBatchOp>(node.op), *batchLane, knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::SendTensor:
|
|
||||||
coreCodeGen.codeGenSendTensorOp(cast<pim::PimSendTensorOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::SendTensorBatch:
|
|
||||||
if (!batchLane || !batchLaneCount)
|
|
||||||
return failure();
|
|
||||||
coreCodeGen.codeGenSendTensorBatchOp(cast<pim::PimSendTensorBatchOp>(node.op),
|
|
||||||
getLaneChunkCoreIds(cast<pim::PimSendTensorBatchOp>(node.op).getTargetCoreIds(),
|
|
||||||
*batchLaneCount,
|
|
||||||
*batchLane),
|
|
||||||
knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::Concat:
|
case CompiledCoreOpKind::Concat:
|
||||||
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
|
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -178,17 +178,7 @@ public:
|
|||||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
|
|
||||||
llvm::ArrayRef<int32_t> laneCoreIds,
|
|
||||||
const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenSendBatchOp(pim::PimSendBatchOp sendOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
|
|
||||||
llvm::ArrayRef<int32_t> laneCoreIds,
|
|
||||||
const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
|
|||||||
@@ -18,27 +18,6 @@ using namespace onnx_mlir::pim;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
|
||||||
|
|
||||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
|
||||||
APInt constantValue;
|
|
||||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
|
||||||
return failure();
|
|
||||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
|
||||||
SmallVector<int32_t> constants;
|
|
||||||
constants.reserve(values.size());
|
|
||||||
for (Value value : values) {
|
|
||||||
FailureOr<int32_t> constantValue = getConstantI32Value(value);
|
|
||||||
if (failed(constantValue))
|
|
||||||
return failure();
|
|
||||||
constants.push_back(*constantValue);
|
|
||||||
}
|
|
||||||
return constants;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
return operandIndex == 2;
|
return operandIndex == 2;
|
||||||
@@ -62,43 +41,6 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
|||||||
return coreIds;
|
return coreIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
|
||||||
IRMapping& mapper,
|
|
||||||
IRRewriter& rewriter) {
|
|
||||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
|
|
||||||
if (failed(targetCoreIds))
|
|
||||||
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
|
|
||||||
for (int32_t& targetCoreId : *targetCoreIds)
|
|
||||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
|
||||||
|
|
||||||
pim::PimSendTensorBatchOp::create(rewriter,
|
|
||||||
sendTensorBatchOp.getLoc(),
|
|
||||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
|
||||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
|
||||||
IRMapping& mapper,
|
|
||||||
IRRewriter& rewriter) {
|
|
||||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
|
|
||||||
if (failed(sourceCoreIds))
|
|
||||||
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
|
|
||||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
|
||||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
|
||||||
|
|
||||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
|
||||||
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
|
|
||||||
receiveTensorBatchOp.getLoc(),
|
|
||||||
outputBuffer.getType(),
|
|
||||||
outputBuffer,
|
|
||||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
|
||||||
.getOutput();
|
|
||||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||||
if (!result.hasOneUse())
|
if (!result.hasOneUse())
|
||||||
return failure();
|
return failure();
|
||||||
@@ -304,51 +246,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
|
||||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
|
||||||
if (failed(targetCoreIds))
|
|
||||||
return sendBatchOp.emitOpError("expected constant targetCoreIds");
|
|
||||||
for (int32_t& targetCoreId : *targetCoreIds)
|
|
||||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
|
||||||
pim::PimSendBatchOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
mapper.lookup(sendBatchOp.getInput()),
|
|
||||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
|
||||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
|
||||||
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
|
|
||||||
return failure();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
|
||||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
|
|
||||||
if (failed(sourceCoreIds))
|
|
||||||
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
|
|
||||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
|
||||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
|
||||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
|
||||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
outputBuffer.getType(),
|
|
||||||
outputBuffer,
|
|
||||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
|
||||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
|
||||||
.getOutput();
|
|
||||||
mapper.map(receiveBatchOp.getOutput(), received);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
|
||||||
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
|
|
||||||
return failure();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||||
Operation* cloned = rewriter.clone(op, mapper);
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -11,20 +9,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
|
||||||
|
|
||||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
|
||||||
SmallVector<int32_t> constants;
|
|
||||||
constants.reserve(values.size());
|
|
||||||
for (Value value : values) {
|
|
||||||
APInt constantValue;
|
|
||||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
|
||||||
return failure();
|
|
||||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
|
||||||
}
|
|
||||||
return constants;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -59,42 +43,6 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
|
||||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
|
|
||||||
if (failed(targetCoreIds))
|
|
||||||
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
|
|
||||||
for (int32_t& targetCoreId : *targetCoreIds)
|
|
||||||
targetCoreId = toPimCoreId(targetCoreId);
|
|
||||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
|
||||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
|
|
||||||
if (failed(sourceCoreIds))
|
|
||||||
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
|
|
||||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
|
||||||
sourceCoreId = toPimCoreId(sourceCoreId);
|
|
||||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
|
||||||
Value outputBuffer =
|
|
||||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
|
||||||
Value received =
|
|
||||||
pim::PimReceiveTensorOp::create(
|
|
||||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
|
||||||
.getOutput();
|
|
||||||
rewriter.replaceOp(op, received);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -137,12 +85,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||||
patterns.add<ChannelSendLowering,
|
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
|
||||||
ChannelReceiveLowering,
|
|
||||||
ChannelSendTensorLowering,
|
|
||||||
ChannelReceiveTensorLowering,
|
|
||||||
ExtractRowsLowering,
|
|
||||||
ConcatLowering>(patterns.getContext());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -53,20 +53,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
|
||||||
|
|
||||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
|
||||||
SmallVector<int32_t> constants;
|
|
||||||
constants.reserve(values.size());
|
|
||||||
for (Value value : values) {
|
|
||||||
APInt constantValue;
|
|
||||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
|
||||||
return failure();
|
|
||||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
|
||||||
}
|
|
||||||
return constants;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||||
@@ -186,25 +172,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
|
||||||
if (receiveTensorOp && !blockArg->use_empty()) {
|
|
||||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
|
||||||
if (failed(sourceCoreIds))
|
|
||||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
|
||||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
|
||||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
|
||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
|
||||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
|
||||||
Value received = PimReceiveTensorOp::create(rewriter,
|
|
||||||
receiveTensorOp.getLoc(),
|
|
||||||
outputBuffer.getType(),
|
|
||||||
outputBuffer,
|
|
||||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
|
||||||
.getOutput();
|
|
||||||
blockArg->replaceAllUsesWith(received);
|
|
||||||
markOpToRemove(receiveTensorOp);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||||
|
|||||||
@@ -607,8 +607,6 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
|
||||||
markOpToRemove(receiveTensorOp);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
|
|||||||
@@ -156,11 +156,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
BuiltinDialect>();
|
BuiltinDialect>();
|
||||||
target.addLegalOp<spatial::SpatConcatOp,
|
target.addLegalOp<spatial::SpatConcatOp,
|
||||||
spatial::SpatChannelReceiveOp,
|
spatial::SpatChannelReceiveOp,
|
||||||
spatial::SpatChannelReceiveTensorOp,
|
|
||||||
spatial::SpatChannelReceiveTensorBatchOp,
|
|
||||||
spatial::SpatChannelSendOp,
|
spatial::SpatChannelSendOp,
|
||||||
spatial::SpatChannelSendTensorOp,
|
|
||||||
spatial::SpatChannelSendTensorBatchOp,
|
|
||||||
spatial::SpatExtractRowsOp>();
|
spatial::SpatExtractRowsOp>();
|
||||||
|
|
||||||
RewritePatternSet initialPatterns(ctx);
|
RewritePatternSet initialPatterns(ctx);
|
||||||
@@ -234,11 +230,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
BuiltinDialect>();
|
BuiltinDialect>();
|
||||||
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
|
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
|
||||||
spatial::SpatChannelReceiveOp,
|
spatial::SpatChannelReceiveOp,
|
||||||
spatial::SpatChannelReceiveTensorOp,
|
|
||||||
spatial::SpatChannelReceiveTensorBatchOp,
|
|
||||||
spatial::SpatChannelSendOp,
|
spatial::SpatChannelSendOp,
|
||||||
spatial::SpatChannelSendTensorOp,
|
|
||||||
spatial::SpatChannelSendTensorBatchOp,
|
|
||||||
spatial::SpatExtractRowsOp>();
|
spatial::SpatExtractRowsOp>();
|
||||||
|
|
||||||
SmallVector<pim::PimCoreOp> coreOps;
|
SmallVector<pim::PimCoreOp> coreOps;
|
||||||
@@ -282,9 +274,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
communicationTarget.addLegalOp<ModuleOp>();
|
communicationTarget.addLegalOp<ModuleOp>();
|
||||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||||
spatial::SpatChannelReceiveOp,
|
spatial::SpatChannelReceiveOp,
|
||||||
spatial::SpatChannelReceiveTensorOp,
|
|
||||||
spatial::SpatChannelSendOp,
|
spatial::SpatChannelSendOp,
|
||||||
spatial::SpatChannelSendTensorOp,
|
|
||||||
spatial::SpatExtractRowsOp>();
|
spatial::SpatExtractRowsOp>();
|
||||||
|
|
||||||
RewritePatternSet communicationPatterns(ctx);
|
RewritePatternSet communicationPatterns(ctx);
|
||||||
|
|||||||
@@ -102,42 +102,6 @@ def PimSendOp : PimOp<"send", []> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimSendTensorOp : PimOp<"send_tensor", []> {
|
|
||||||
let summary = "Send equal contiguous chunks of one tensor to target cores";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$input,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimSendBatchOp : PimOp<"send_batch", []> {
|
|
||||||
let summary = "Send a per-lane tensor to target cores from a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$input,
|
|
||||||
I32Attr:$size,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
|
|
||||||
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$input,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Receive a tensor from another core";
|
let summary = "Receive a tensor from another core";
|
||||||
|
|
||||||
@@ -162,72 +126,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Receive equal contiguous chunks from source cores into one tensor";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$outputBuffer,
|
|
||||||
DenseI32ArrayAttr:$sourceCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutputBufferMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Receive per-lane tensors from source cores into a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$outputBuffer,
|
|
||||||
I32Attr:$size,
|
|
||||||
DenseI32ArrayAttr:$sourceCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutputBufferMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$outputBuffer,
|
|
||||||
DenseI32ArrayAttr:$sourceCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutputBufferMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Copy a memory region from host memory into device memory";
|
let summary = "Copy a memory region from host memory into device memory";
|
||||||
|
|
||||||
|
|||||||
@@ -28,34 +28,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
|
|||||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
|
||||||
printer << "(";
|
|
||||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
|
||||||
if (index != 0)
|
|
||||||
printer << ", ";
|
|
||||||
printer.printOperand(argument);
|
|
||||||
}
|
|
||||||
printer << ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
|
||||||
if (parser.parseLParen())
|
|
||||||
return failure();
|
|
||||||
if (succeeded(parser.parseOptionalRParen()))
|
|
||||||
return success();
|
|
||||||
|
|
||||||
OpAsmParser::Argument argument;
|
|
||||||
if (parser.parseArgument(argument))
|
|
||||||
return failure();
|
|
||||||
arguments.push_back(argument);
|
|
||||||
while (succeeded(parser.parseOptionalComma())) {
|
|
||||||
if (parser.parseArgument(argument))
|
|
||||||
return failure();
|
|
||||||
arguments.push_back(argument);
|
|
||||||
}
|
|
||||||
return parser.parseRParen();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void
|
static void
|
||||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||||
printCompressedValueList(printer, arguments, delimiter);
|
printCompressedValueList(printer, arguments, delimiter);
|
||||||
@@ -98,12 +70,6 @@ static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<i
|
|||||||
printCompressedIntegerList(printer, coreIds);
|
printCompressedIntegerList(printer, coreIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
|
|
||||||
if (failed(parser.parseOptionalKeyword(keyword)))
|
|
||||||
return success();
|
|
||||||
return parseCompressedIntegerList(parser, coreIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PimCoreOp::print(OpAsmPrinter& printer) {
|
void PimCoreOp::print(OpAsmPrinter& printer) {
|
||||||
@@ -295,198 +261,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
|
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimSendBatchOp::print(OpAsmPrinter& printer) {
|
|
||||||
printer << " ";
|
|
||||||
printer.printOperand(getInput());
|
|
||||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getInput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand input;
|
|
||||||
Type inputType;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
|
|
||||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
|
||||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
if (!targetCoreIds.empty())
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
|
|
||||||
return parser.resolveOperand(input, inputType, result.operands);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
|
|
||||||
printer << " ";
|
|
||||||
printer.printOperand(getInput());
|
|
||||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getInput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand input;
|
|
||||||
Type inputType;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
|
|
||||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
|
||||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
if (!targetCoreIds.empty())
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
|
|
||||||
return parser.resolveOperand(input, inputType, result.operands);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimSendTensorOp::print(OpAsmPrinter& printer) {
|
|
||||||
printer << " ";
|
|
||||||
printer.printOperand(getInput());
|
|
||||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getInput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand input;
|
|
||||||
Type inputType;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
|
|
||||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
|
||||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
if (!targetCoreIds.empty())
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
|
|
||||||
return parser.resolveOperand(input, inputType, result.operands);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
|
|
||||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
|
||||||
printer << " into ";
|
|
||||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
|
||||||
printer.printOperand(getOutputBuffer());
|
|
||||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getOutputBuffer().getType());
|
|
||||||
printer << " -> ";
|
|
||||||
printer.printType(getOutput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
|
||||||
Type outputBufferType;
|
|
||||||
Type outputType;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
|
|
||||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
|
||||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
|
||||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
|
||||||
|| parser.parseType(outputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
if (!sourceCoreIds.empty())
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
|
|
||||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(outputType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
|
|
||||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
|
||||||
printer << " into ";
|
|
||||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
|
||||||
printer.printOperand(getOutputBuffer());
|
|
||||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getOutputBuffer().getType());
|
|
||||||
printer << " -> ";
|
|
||||||
printer.printType(getOutput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
|
||||||
Type outputBufferType;
|
|
||||||
Type outputType;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
|
|
||||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
|
||||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
|
||||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
|
||||||
|| parser.parseType(outputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
if (!sourceCoreIds.empty())
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
|
|
||||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(outputType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
|
|
||||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
|
||||||
printer << " into ";
|
|
||||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
|
||||||
printer.printOperand(getOutputBuffer());
|
|
||||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getOutputBuffer().getType());
|
|
||||||
printer << " -> ";
|
|
||||||
printer.printType(getOutput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
|
||||||
Type outputBufferType;
|
|
||||||
Type outputType;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
|
|
||||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
|
||||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
|
||||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
|
||||||
|| parser.parseType(outputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
if (!sourceCoreIds.empty())
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
|
|
||||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(outputType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimConcatOp::print(OpAsmPrinter& printer) {
|
void PimConcatOp::print(OpAsmPrinter& printer) {
|
||||||
printer << " axis " << getAxis() << " ";
|
printer << " axis " << getAxis() << " ";
|
||||||
printCompressedValueSequence(printer, getInputs());
|
printCompressedValueSequence(printer, getInputs());
|
||||||
|
|||||||
@@ -90,56 +90,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
|
||||||
if (coreIds.empty())
|
|
||||||
return op->emitError() << kind << " must carry at least one chunk";
|
|
||||||
|
|
||||||
auto shapedType = dyn_cast<ShapedType>(type);
|
|
||||||
if (!shapedType || !shapedType.hasStaticShape())
|
|
||||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
|
||||||
|
|
||||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
|
||||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
|
||||||
return op->emitError() << kind << " requires byte-sized elements";
|
|
||||||
|
|
||||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
|
||||||
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
|
|
||||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult
|
|
||||||
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
|
||||||
if (coreIds.empty())
|
|
||||||
return op->emitError() << kind << " must carry at least one chunk";
|
|
||||||
|
|
||||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
|
||||||
if (!coreBatchOp)
|
|
||||||
return op->emitError() << kind << " must be nested inside pim.core_batch";
|
|
||||||
|
|
||||||
int32_t laneCount = coreBatchOp.getLaneCount();
|
|
||||||
if (laneCount <= 0)
|
|
||||||
return op->emitError() << kind << " requires a positive parent laneCount";
|
|
||||||
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
|
|
||||||
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
|
|
||||||
|
|
||||||
auto shapedType = dyn_cast<ShapedType>(type);
|
|
||||||
if (!shapedType || !shapedType.hasStaticShape())
|
|
||||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
|
||||||
|
|
||||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
|
||||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
|
||||||
return op->emitError() << kind << " requires byte-sized elements";
|
|
||||||
|
|
||||||
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
|
|
||||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
|
||||||
if (totalBytes % chunkCount != 0)
|
|
||||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
||||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||||
if (!shapedType)
|
if (!shapedType)
|
||||||
@@ -177,31 +127,6 @@ LogicalResult PimCoreBatchOp::verify() {
|
|||||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult PimSendTensorOp::verify() {
|
|
||||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult PimSendTensorBatchOp::verify() {
|
|
||||||
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult PimReceiveTensorOp::verify() {
|
|
||||||
if (failed(verifyCompatibleShapedTypes(
|
|
||||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult PimReceiveTensorBatchOp::verify() {
|
|
||||||
if (failed(verifyCompatibleShapedTypes(
|
|
||||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
return verifyTensorBatchCommunication(
|
|
||||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult PimVMMOp::verify() {
|
LogicalResult PimVMMOp::verify() {
|
||||||
if (failed(verifyCompatibleShapedTypes(
|
if (failed(verifyCompatibleShapedTypes(
|
||||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||||
|
|||||||
@@ -157,72 +157,6 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto receiveOp = cast<PimReceiveBatchOp>(op);
|
|
||||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
|
||||||
if (failed(outputBufferOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
|
|
||||||
op,
|
|
||||||
outputBufferOpt->getType(),
|
|
||||||
*outputBufferOpt,
|
|
||||||
receiveOp.getSizeAttr(),
|
|
||||||
receiveOp.getSourceCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ReceiveTensorOpInterface
|
|
||||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto receiveOp = cast<PimReceiveTensorOp>(op);
|
|
||||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
|
||||||
if (failed(outputBufferOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
|
|
||||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ReceiveTensorBatchOpInterface
|
|
||||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
|
|
||||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
|
||||||
if (failed(outputBufferOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
|
|
||||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
|
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
@@ -252,30 +186,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto sendOp = cast<PimSendTensorOp>(op);
|
|
||||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
|
||||||
if (failed(inputOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
|
|
||||||
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
|
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||||
|
|
||||||
@@ -303,58 +213,6 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto sendOp = cast<PimSendBatchOp>(op);
|
|
||||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
|
||||||
if (failed(inputOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
|
|
||||||
op,
|
|
||||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
|
||||||
sendOp.getSizeAttr(),
|
|
||||||
sendOp.getTargetCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SendTensorBatchOpInterface
|
|
||||||
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto sendOp = cast<PimSendTensorBatchOp>(op);
|
|
||||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
|
||||||
if (failed(inputOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
|
|
||||||
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
|
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||||
|
|
||||||
@@ -699,13 +557,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
|||||||
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
|
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
|
||||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||||
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
|
|
||||||
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
|
|
||||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
|
||||||
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
||||||
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
|
|
||||||
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
|
|
||||||
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
|
|
||||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -194,111 +194,6 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
|
|
||||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<Index>:$channelIds,
|
|
||||||
Variadic<Index>:$sourceCoreIds,
|
|
||||||
Variadic<Index>:$targetCoreIds,
|
|
||||||
SpatTensor:$input
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
|
|
||||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<Index>:$channelIds,
|
|
||||||
Variadic<Index>:$sourceCoreIds,
|
|
||||||
Variadic<Index>:$targetCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
|
|
||||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<Index>:$channelIds,
|
|
||||||
Variadic<Index>:$sourceCoreIds,
|
|
||||||
Variadic<Index>:$targetCoreIds,
|
|
||||||
SpatTensor:$input
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
|
|
||||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<Index>:$channelIds,
|
|
||||||
Variadic<Index>:$sourceCoreIds,
|
|
||||||
Variadic<Index>:$targetCoreIds,
|
|
||||||
SpatTensor:$input
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
// spat.channel_receive_batch: receives a tensor through logical channels from
// inside a batch body. The op takes three variadic index lists and yields the
// received tensor as its single result `$output`.
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
  let summary = "Receive a per-lane tensor through logical channels in a batch body";

  // Several variadic operand groups are present, which is why the op carries
  // AttrSizedOperandSegments.
  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds);

  let results = (outs SpatTensor:$output);

  // Textual form: the channel / source / target lists, then the result type.
  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];

  // A hand-written C++ verify() is declared for this op.
  let hasVerifier = 1;
}
|
|
||||||
|
|
||||||
// spat.channel_receive_tensor_batch: receives a tensor through logical
// channels from inside a batch body, assembled from equal contiguous chunks as
// stated by the summary. The received tensor is the single result `$output`.
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
  let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";

  // Several variadic operand groups are present, which is why the op carries
  // AttrSizedOperandSegments.
  let arguments = (ins Variadic<Index>:$channelIds,
                       Variadic<Index>:$sourceCoreIds,
                       Variadic<Index>:$targetCoreIds);

  let results = (outs SpatTensor:$output);

  // Textual form: the channel / source / target lists, then the result type.
  let assemblyFormat = [{
    `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
  }];

  // A hand-written C++ verify() is declared for this op.
  let hasVerifier = 1;
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Math
|
// Math
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
@@ -95,13 +95,6 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
|||||||
return shapedType.getShape();
|
return shapedType.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Looks up the SpatComputeBatch ancestor of `op` and returns its lane count.
/// Yields failure when `op` is not nested inside a SpatComputeBatch.
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
  if (SpatComputeBatch enclosingBatch = op->getParentOfType<SpatComputeBatch>())
    return enclosingBatch.getLaneCount();
  return failure();
}
|
|
||||||
|
|
||||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||||
if (batchOp.getNumResults() == 0)
|
if (batchOp.getNumResults() == 0)
|
||||||
return false;
|
return false;
|
||||||
@@ -233,68 +226,6 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Shared size checks for the non-batch tensor channel ops.
///
/// Emits an error on `op` and returns failure unless all of the following hold:
///  - `channelCount`, `sourceCoreCount` and `targetCoreCount` are equal;
///  - `channelCount` is non-zero;
///  - `type` is a statically shaped ShapedType whose element type is an integer
///    or float with a bit width that is a whole number of bytes;
///  - the total byte size of the tensor is divisible by `channelCount`.
///
/// `kind` is the label prepended to the type- and size-related diagnostics.
static LogicalResult verifyTensorChannelSizes(
    Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
  if (channelCount == 0)
    return op->emitError() << kind << " must carry at least one chunk";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor";

  // The bit-width query is only defined for integer and float element types;
  // reject anything else with a diagnostic instead of tripping an assertion.
  if (!shapedType.getElementType().isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";

  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // elementBits is a multiple of 8 at this point, so dividing first is exact
  // and keeps the intermediate product as small as possible.
  int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % static_cast<int64_t>(channelCount) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
  return success();
}
|
|
||||||
|
|
||||||
/// Shared size checks for the per-lane batch channel ops.
///
/// Fails with a diagnostic on `op` when the three id lists differ in length,
/// when `op` has no enclosing spat.compute_batch, or when the list length does
/// not equal the lane count of that enclosing batch.
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
  const bool listsAgree = channelCount == sourceCoreCount && channelCount == targetCoreCount;
  if (!listsAgree)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  FailureOr<int32_t> parentLanes = getParentBatchLaneCount(op);
  if (failed(parentLanes))
    return op->emitError("must be nested inside spat.compute_batch");

  // One metadata entry is expected per lane of the parent batch.
  const auto expectedEntries = static_cast<size_t>(*parentLanes);
  if (channelCount != expectedEntries)
    return op->emitError("channel metadata length must match parent laneCount");

  return success();
}
|
|
||||||
|
|
||||||
/// Shared size checks for the tensor batch channel ops.
///
/// Emits an error on `op` and returns failure unless all of the following hold:
///  - `channelCount`, `sourceCoreCount` and `targetCoreCount` are equal;
///  - `op` is nested inside a spat.compute_batch with a positive lane count;
///  - `channelCount` is a positive multiple of that lane count;
///  - `type` is a statically shaped ShapedType whose element type is an integer
///    or float with a bit width that is a whole number of bytes;
///  - the total byte size of the tensor is divisible by the number of chunks
///    per lane (`channelCount / laneCount`).
///
/// `kind` is the label prepended to the diagnostics that follow the nesting
/// check.
static LogicalResult verifyTensorBatchChannelSizes(
    Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  auto laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside spat.compute_batch");

  // The lane count is used as a divisor below; a zero value would divide by
  // zero and a negative one would wrap when converted to size_t.
  if (*laneCount <= 0)
    return op->emitError() << kind << " requires a positive parent laneCount";
  if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
    return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor";

  // The bit-width query is only defined for integer and float element types;
  // reject anything else with a diagnostic instead of tripping an assertion.
  if (!shapedType.getElementType().isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";

  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // chunkCount is at least 1 because channelCount is a positive multiple of
  // the (positive) lane count.
  int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
  // elementBits is a multiple of 8 at this point, so dividing first is exact.
  int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % chunkCount != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";

  return success();
}
|
|
||||||
|
|
||||||
static Region* getParentRegion(Value value) {
|
static Region* getParentRegion(Value value) {
|
||||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
return blockArg.getOwner()->getParent();
|
return blockArg.getOwner()->getParent();
|
||||||
@@ -564,52 +495,6 @@ LogicalResult SpatCompute::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for channel_send_tensor: forwards the input type and the lengths of
/// the three id operand lists to verifyTensorChannelSizes.
LogicalResult SpatChannelSendTensorOp::verify() {
  const size_t numChannels = getChannelIds().size();
  const size_t numSources = getSourceCoreIds().size();
  const size_t numTargets = getTargetCoreIds().size();
  Type payloadType = getInput().getType();
  return verifyTensorChannelSizes(
      getOperation(), payloadType, numChannels, numSources, numTargets, "channel_send_tensor");
}
|
|
||||||
|
|
||||||
/// Verifier for channel_receive_tensor: forwards the result type and the
/// lengths of the three id operand lists to verifyTensorChannelSizes.
LogicalResult SpatChannelReceiveTensorOp::verify() {
  const size_t numChannels = getChannelIds().size();
  const size_t numSources = getSourceCoreIds().size();
  const size_t numTargets = getTargetCoreIds().size();
  Type payloadType = getOutput().getType();
  return verifyTensorChannelSizes(
      getOperation(), payloadType, numChannels, numSources, numTargets, "channel_receive_tensor");
}
|
|
||||||
|
|
||||||
/// Verifier for channel_send_batch: forwards the lengths of the three id
/// operand lists to verifyBatchChannelSizes.
LogicalResult SpatChannelSendBatchOp::verify() {
  const size_t numChannels = getChannelIds().size();
  const size_t numSources = getSourceCoreIds().size();
  const size_t numTargets = getTargetCoreIds().size();
  return verifyBatchChannelSizes(getOperation(), numChannels, numSources, numTargets);
}
|
|
||||||
|
|
||||||
/// Verifier for channel_send_tensor_batch: forwards the input type and the
/// lengths of the three id operand lists to verifyTensorBatchChannelSizes.
LogicalResult SpatChannelSendTensorBatchOp::verify() {
  const size_t numChannels = getChannelIds().size();
  const size_t numSources = getSourceCoreIds().size();
  const size_t numTargets = getTargetCoreIds().size();
  Type payloadType = getInput().getType();
  return verifyTensorBatchChannelSizes(
      getOperation(), payloadType, numChannels, numSources, numTargets, "channel_send_tensor_batch");
}
|
|
||||||
|
|
||||||
/// Verifier for channel_receive_batch: forwards the lengths of the three id
/// operand lists to verifyBatchChannelSizes.
LogicalResult SpatChannelReceiveBatchOp::verify() {
  const size_t numChannels = getChannelIds().size();
  const size_t numSources = getSourceCoreIds().size();
  const size_t numTargets = getTargetCoreIds().size();
  return verifyBatchChannelSizes(getOperation(), numChannels, numSources, numTargets);
}
|
|
||||||
|
|
||||||
/// Verifier for channel_receive_tensor_batch: forwards the result type and the
/// lengths of the three id operand lists to verifyTensorBatchChannelSizes.
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
  const size_t numChannels = getChannelIds().size();
  const size_t numSources = getSourceCoreIds().size();
  const size_t numTargets = getTargetCoreIds().size();
  Type payloadType = getOutput().getType();
  return verifyTensorBatchChannelSizes(
      getOperation(), payloadType, numChannels, numSources, numTargets, "channel_receive_tensor_batch");
}
|
|
||||||
|
|
||||||
LogicalResult SpatComputeBatch::verify() {
|
LogicalResult SpatComputeBatch::verify() {
|
||||||
int32_t count = getLaneCount();
|
int32_t count = getLaneCount();
|
||||||
if (count <= 0)
|
if (count <= 0)
|
||||||
|
|||||||
@@ -79,8 +79,6 @@ struct MergeIrCounts {
|
|||||||
uint64_t topLevelComputeBatchCount = 0;
|
uint64_t topLevelComputeBatchCount = 0;
|
||||||
uint64_t scalarChannelSendCount = 0;
|
uint64_t scalarChannelSendCount = 0;
|
||||||
uint64_t scalarChannelReceiveCount = 0;
|
uint64_t scalarChannelReceiveCount = 0;
|
||||||
uint64_t tensorChannelSendCount = 0;
|
|
||||||
uint64_t tensorChannelReceiveCount = 0;
|
|
||||||
uint64_t wvmmCount = 0;
|
uint64_t wvmmCount = 0;
|
||||||
uint64_t vaddCount = 0;
|
uint64_t vaddCount = 0;
|
||||||
uint64_t scfForCount = 0;
|
uint64_t scfForCount = 0;
|
||||||
@@ -95,10 +93,6 @@ MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
|
|||||||
++counts.scalarChannelSendCount;
|
++counts.scalarChannelSendCount;
|
||||||
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
|
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
|
||||||
++counts.scalarChannelReceiveCount;
|
++counts.scalarChannelReceiveCount;
|
||||||
else if (isa<spatial::SpatChannelSendTensorOp, spatial::SpatChannelSendTensorBatchOp>(nestedOp))
|
|
||||||
++counts.tensorChannelSendCount;
|
|
||||||
else if (isa<spatial::SpatChannelReceiveTensorOp, spatial::SpatChannelReceiveTensorBatchOp>(nestedOp))
|
|
||||||
++counts.tensorChannelReceiveCount;
|
|
||||||
else if (isa<spatial::SpatVMMOp>(nestedOp))
|
else if (isa<spatial::SpatVMMOp>(nestedOp))
|
||||||
++counts.wvmmCount;
|
++counts.wvmmCount;
|
||||||
else if (isa<spatial::SpatVAddOp>(nestedOp))
|
else if (isa<spatial::SpatVAddOp>(nestedOp))
|
||||||
@@ -130,9 +124,8 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
|||||||
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
||||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
<< " scalar_send=" << counts.scalarChannelSendCount
|
||||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
||||||
<< " tensor_send=" << counts.tensorChannelSendCount
|
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount
|
||||||
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
<< " scf_for=" << counts.scfForCount << "\n";
|
||||||
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
|
|||||||
@@ -150,13 +150,7 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
|||||||
pim::PimMemCopyDevToHostOp,
|
pim::PimMemCopyDevToHostOp,
|
||||||
pim::PimMemCopyOp,
|
pim::PimMemCopyOp,
|
||||||
pim::PimReceiveOp,
|
pim::PimReceiveOp,
|
||||||
pim::PimReceiveBatchOp,
|
|
||||||
pim::PimReceiveTensorOp,
|
|
||||||
pim::PimReceiveTensorBatchOp,
|
|
||||||
pim::PimSendOp,
|
pim::PimSendOp,
|
||||||
pim::PimSendBatchOp,
|
|
||||||
pim::PimSendTensorOp,
|
|
||||||
pim::PimSendTensorBatchOp,
|
|
||||||
pim::PimConcatOp,
|
pim::PimConcatOp,
|
||||||
pim::PimVMMOp,
|
pim::PimVMMOp,
|
||||||
pim::PimTransposeOp,
|
pim::PimTransposeOp,
|
||||||
@@ -173,18 +167,6 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
|||||||
memref::GetGlobalOp>(op);
|
memref::GetGlobalOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns `type` as a ShapedType when it has a static shape and an integer or
/// float element type whose bit width is a positive multiple of 8; returns
/// failure otherwise.
static FailureOr<ShapedType> getStaticByteSizedShapedType(Type type) {
  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return failure();

  // The bit-width query is only defined for integer and float element types;
  // report failure for anything else instead of tripping an assertion.
  if (!shapedType.getElementType().isIntOrFloat())
    return failure();

  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return failure();

  return shapedType;
}
|
|
||||||
|
|
||||||
static LogicalResult verifyBatchOpSemantics(Operation& op,
|
static LogicalResult verifyBatchOpSemantics(Operation& op,
|
||||||
const StaticValueKnowledge& knowledge,
|
const StaticValueKnowledge& knowledge,
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
@@ -203,73 +185,6 @@ static LogicalResult verifyBatchOpSemantics(Operation& op,
|
|||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
|
||||||
if (sendBatchOp.getTargetCoreIds().size() != static_cast<size_t>(sendBatchOp->getParentOfType<pim::PimCoreBatchOp>()
|
|
||||||
.getLaneCount())) {
|
|
||||||
reportFailure([](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError("targetCoreIds size must match parent laneCount");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
|
||||||
if (receiveBatchOp.getSourceCoreIds().size()
|
|
||||||
!= static_cast<size_t>(receiveBatchOp->getParentOfType<pim::PimCoreBatchOp>().getLaneCount())) {
|
|
||||||
reportFailure([](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError("sourceCoreIds size must match parent laneCount");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto verifyTensorBatchCommunication = [&](Value tensorValue, ArrayRef<int32_t> coreIds, StringRef kind) {
|
|
||||||
if (coreIds.empty()) {
|
|
||||||
reportFailure([&](Operation* illegalOp) { illegalOp->emitOpError() << kind << " must carry at least one chunk"; });
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto parentBatchOp = op.getParentOfType<pim::PimCoreBatchOp>();
|
|
||||||
int32_t laneCount = parentBatchOp.getLaneCount();
|
|
||||||
if (laneCount <= 0) {
|
|
||||||
reportFailure([&](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError() << kind << " requires a positive parent laneCount";
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (coreIds.size() % static_cast<size_t>(laneCount) != 0) {
|
|
||||||
reportFailure([&](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError() << kind << " core id count must be divisible by the parent laneCount";
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shapedType = getStaticByteSizedShapedType(tensorValue.getType());
|
|
||||||
if (failed(shapedType)) {
|
|
||||||
reportFailure([&](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError() << kind << " requires a static shaped tensor or memref with byte-sized elements";
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
|
|
||||||
int64_t totalBytes = (*shapedType).getNumElements() * (*shapedType).getElementTypeBitWidth() / 8;
|
|
||||||
if (totalBytes % chunkCount != 0) {
|
|
||||||
reportFailure([&](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op))
|
|
||||||
verifyTensorBatchCommunication(sendTensorBatchOp.getInput(),
|
|
||||||
sendTensorBatchOp.getTargetCoreIds(),
|
|
||||||
"send_tensor_batch");
|
|
||||||
else if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op))
|
|
||||||
verifyTensorBatchCommunication(receiveTensorBatchOp.getOutput(),
|
|
||||||
receiveTensorBatchOp.getSourceCoreIds(),
|
|
||||||
"receive_tensor_batch");
|
|
||||||
|
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user