cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -519,61 +519,12 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
||||
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
|
||||
/ receiveTensorOp.getSourceCoreIds().size();
|
||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp,
|
||||
unsigned lane,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp(
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreIds()[lane], receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
|
||||
ArrayRef<int32_t> laneCoreIds,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
size_t outputAddr = addressOf(receiveOp.getOutputBuffer(), knowledge);
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveOp.getOutputBuffer().getType()))
|
||||
/ laneCoreIds.size();
|
||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(laneCoreIds))
|
||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto targetCoreId = indexOf(sendOp.getTargetCoreId(), knowledge);
|
||||
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
|
||||
/ sendTensorOp.getTargetCoreIds().size();
|
||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendBatchOp(pim::PimSendBatchOp sendOp,
|
||||
unsigned lane,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreIds()[lane], sendOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
|
||||
ArrayRef<int32_t> laneCoreIds,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
size_t inputAddr = addressOf(sendOp.getInput(), knowledge);
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendOp.getInput().getType())) / laneCoreIds.size();
|
||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(laneCoreIds))
|
||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
|
||||
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
|
||||
@@ -902,13 +853,7 @@ enum class CompiledCoreOpKind : uint8_t {
|
||||
Store,
|
||||
Lmv,
|
||||
Receive,
|
||||
ReceiveBatch,
|
||||
ReceiveTensor,
|
||||
ReceiveTensorBatch,
|
||||
Send,
|
||||
SendBatch,
|
||||
SendTensor,
|
||||
SendTensorBatch,
|
||||
Concat,
|
||||
Vmm,
|
||||
Transpose,
|
||||
@@ -952,20 +897,8 @@ static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
|
||||
return CompiledCoreOpKind::Lmv;
|
||||
if (isa<pim::PimReceiveOp>(op))
|
||||
return CompiledCoreOpKind::Receive;
|
||||
if (isa<pim::PimReceiveBatchOp>(op))
|
||||
return CompiledCoreOpKind::ReceiveBatch;
|
||||
if (isa<pim::PimReceiveTensorOp>(op))
|
||||
return CompiledCoreOpKind::ReceiveTensor;
|
||||
if (isa<pim::PimReceiveTensorBatchOp>(op))
|
||||
return CompiledCoreOpKind::ReceiveTensorBatch;
|
||||
if (isa<pim::PimSendOp>(op))
|
||||
return CompiledCoreOpKind::Send;
|
||||
if (isa<pim::PimSendBatchOp>(op))
|
||||
return CompiledCoreOpKind::SendBatch;
|
||||
if (isa<pim::PimSendTensorOp>(op))
|
||||
return CompiledCoreOpKind::SendTensor;
|
||||
if (isa<pim::PimSendTensorBatchOp>(op))
|
||||
return CompiledCoreOpKind::SendTensorBatch;
|
||||
if (isa<pim::PimConcatOp>(op))
|
||||
return CompiledCoreOpKind::Concat;
|
||||
if (isa<pim::PimVMMOp>(op))
|
||||
@@ -1108,43 +1041,9 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
||||
case CompiledCoreOpKind::Receive:
|
||||
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::ReceiveBatch:
|
||||
if (!batchLane)
|
||||
return failure();
|
||||
coreCodeGen.codeGenReceiveBatchOp(cast<pim::PimReceiveBatchOp>(node.op), *batchLane, knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::ReceiveTensor:
|
||||
coreCodeGen.codeGenReceiveTensorOp(cast<pim::PimReceiveTensorOp>(node.op), knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::ReceiveTensorBatch:
|
||||
if (!batchLane || !batchLaneCount)
|
||||
return failure();
|
||||
coreCodeGen.codeGenReceiveTensorBatchOp(cast<pim::PimReceiveTensorBatchOp>(node.op),
|
||||
getLaneChunkCoreIds(cast<pim::PimReceiveTensorBatchOp>(node.op).getSourceCoreIds(),
|
||||
*batchLaneCount,
|
||||
*batchLane),
|
||||
knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::Send:
|
||||
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::SendBatch:
|
||||
if (!batchLane)
|
||||
return failure();
|
||||
coreCodeGen.codeGenSendBatchOp(cast<pim::PimSendBatchOp>(node.op), *batchLane, knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::SendTensor:
|
||||
coreCodeGen.codeGenSendTensorOp(cast<pim::PimSendTensorOp>(node.op), knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::SendTensorBatch:
|
||||
if (!batchLane || !batchLaneCount)
|
||||
return failure();
|
||||
coreCodeGen.codeGenSendTensorBatchOp(cast<pim::PimSendTensorBatchOp>(node.op),
|
||||
getLaneChunkCoreIds(cast<pim::PimSendTensorBatchOp>(node.op).getTargetCoreIds(),
|
||||
*batchLaneCount,
|
||||
*batchLane),
|
||||
knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::Concat:
|
||||
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
|
||||
break;
|
||||
|
||||
@@ -178,17 +178,7 @@ public:
|
||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
|
||||
llvm::ArrayRef<int32_t> laneCoreIds,
|
||||
const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendBatchOp(pim::PimSendBatchOp sendOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
|
||||
llvm::ArrayRef<int32_t> laneCoreIds,
|
||||
const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
template <typename MVMTy>
|
||||
|
||||
@@ -18,27 +18,6 @@ using namespace onnx_mlir::pim;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
FailureOr<int32_t> constantValue = getConstantI32Value(value);
|
||||
if (failed(constantValue))
|
||||
return failure();
|
||||
constants.push_back(*constantValue);
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
@@ -62,43 +41,6 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
|
||||
pim::PimSendTensorBatchOp::create(rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
@@ -304,51 +246,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
loc,
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveBatchOp.getOutput(), received);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -11,20 +9,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -59,42 +43,6 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = toPimCoreId(targetCoreId);
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = toPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -137,12 +85,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
||||
} // namespace
|
||||
|
||||
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<ChannelSendLowering,
|
||||
ChannelReceiveLowering,
|
||||
ChannelSendTensorLowering,
|
||||
ChannelReceiveTensorLowering,
|
||||
ExtractRowsLowering,
|
||||
ConcatLowering>(patterns.getContext());
|
||||
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -53,20 +53,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
||||
}
|
||||
}
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
@@ -186,25 +172,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||
if (receiveTensorOp && !blockArg->use_empty()) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
|
||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||
|
||||
@@ -607,8 +607,6 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||
markOpToRemove(receiveTensorOp);
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
|
||||
@@ -156,11 +156,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
BuiltinDialect>();
|
||||
target.addLegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelReceiveTensorBatchOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet initialPatterns(ctx);
|
||||
@@ -234,11 +230,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
BuiltinDialect>();
|
||||
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelReceiveTensorBatchOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
@@ -282,9 +274,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
communicationTarget.addLegalOp<ModuleOp>();
|
||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
|
||||
@@ -102,42 +102,6 @@ def PimSendOp : PimOp<"send", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSendTensorOp : PimOp<"send_tensor", []> {
|
||||
let summary = "Send equal contiguous chunks of one tensor to target cores";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendBatchOp : PimOp<"send_batch", []> {
|
||||
let summary = "Send a per-lane tensor to target cores from a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
I32Attr:$size,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive a tensor from another core";
|
||||
|
||||
@@ -162,72 +126,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive equal contiguous chunks from source cores into one tensor";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive per-lane tensors from source cores into a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
I32Attr:$size,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
|
||||
@@ -28,34 +28,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static void
|
||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
@@ -98,12 +70,6 @@ static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<i
|
||||
printCompressedIntegerList(printer, coreIds);
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
|
||||
if (failed(parser.parseOptionalKeyword(keyword)))
|
||||
return success();
|
||||
return parseCompressedIntegerList(parser, coreIds);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimCoreOp::print(OpAsmPrinter& printer) {
|
||||
@@ -295,198 +261,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendTensorOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutputBuffer().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
Type outputBufferType;
|
||||
Type outputType;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
||||
|| parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutputBuffer().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
Type outputBufferType;
|
||||
Type outputType;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
||||
|| parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutputBuffer().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
Type outputBufferType;
|
||||
Type outputType;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
||||
|| parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis() << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
|
||||
@@ -90,56 +90,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
return op->emitError() << kind << " must be nested inside pim.core_batch";
|
||||
|
||||
int32_t laneCount = coreBatchOp.getLaneCount();
|
||||
if (laneCount <= 0)
|
||||
return op->emitError() << kind << " requires a positive parent laneCount";
|
||||
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
|
||||
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
@@ -177,31 +127,6 @@ LogicalResult PimCoreBatchOp::verify() {
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
return verifyTensorBatchCommunication(
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimVMMOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
|
||||
@@ -157,72 +157,6 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveBatchOp>(op);
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
|
||||
op,
|
||||
outputBufferOpt->getType(),
|
||||
*outputBufferOpt,
|
||||
receiveOp.getSizeAttr(),
|
||||
receiveOp.getSourceCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveTensorOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveTensorOp>(op);
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveTensorBatchOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -252,30 +186,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
||||
}
|
||||
};
|
||||
|
||||
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto sendOp = cast<PimSendTensorOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
|
||||
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
@@ -303,58 +213,6 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
||||
}
|
||||
};
|
||||
|
||||
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto sendOp = cast<PimSendBatchOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct SendTensorBatchOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto sendOp = cast<PimSendTensorBatchOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
|
||||
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
@@ -699,13 +557,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
|
||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
|
||||
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
|
||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
||||
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
||||
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
|
||||
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
|
||||
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
|
||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
|
||||
@@ -194,111 +194,6 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -95,13 +95,6 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
auto batchOp = op->getParentOfType<SpatComputeBatch>();
|
||||
if (!batchOp)
|
||||
return failure();
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return false;
|
||||
@@ -233,68 +226,6 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelCount == 0)
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelCount != static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match parent laneCount");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorBatchChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
@@ -564,52 +495,6 @@ LogicalResult SpatCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatComputeBatch::verify() {
|
||||
int32_t count = getLaneCount();
|
||||
if (count <= 0)
|
||||
|
||||
@@ -79,8 +79,6 @@ struct MergeIrCounts {
|
||||
uint64_t topLevelComputeBatchCount = 0;
|
||||
uint64_t scalarChannelSendCount = 0;
|
||||
uint64_t scalarChannelReceiveCount = 0;
|
||||
uint64_t tensorChannelSendCount = 0;
|
||||
uint64_t tensorChannelReceiveCount = 0;
|
||||
uint64_t wvmmCount = 0;
|
||||
uint64_t vaddCount = 0;
|
||||
uint64_t scfForCount = 0;
|
||||
@@ -95,10 +93,6 @@ MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
|
||||
++counts.scalarChannelSendCount;
|
||||
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
|
||||
++counts.scalarChannelReceiveCount;
|
||||
else if (isa<spatial::SpatChannelSendTensorOp, spatial::SpatChannelSendTensorBatchOp>(nestedOp))
|
||||
++counts.tensorChannelSendCount;
|
||||
else if (isa<spatial::SpatChannelReceiveTensorOp, spatial::SpatChannelReceiveTensorBatchOp>(nestedOp))
|
||||
++counts.tensorChannelReceiveCount;
|
||||
else if (isa<spatial::SpatVMMOp>(nestedOp))
|
||||
++counts.wvmmCount;
|
||||
else if (isa<spatial::SpatVAddOp>(nestedOp))
|
||||
@@ -130,9 +124,8 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
||||
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
||||
<< " tensor_send=" << counts.tensorChannelSendCount
|
||||
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
||||
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
||||
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount
|
||||
<< " scf_for=" << counts.scfForCount << "\n";
|
||||
}
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
|
||||
@@ -150,13 +150,7 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
||||
pim::PimMemCopyDevToHostOp,
|
||||
pim::PimMemCopyOp,
|
||||
pim::PimReceiveOp,
|
||||
pim::PimReceiveBatchOp,
|
||||
pim::PimReceiveTensorOp,
|
||||
pim::PimReceiveTensorBatchOp,
|
||||
pim::PimSendOp,
|
||||
pim::PimSendBatchOp,
|
||||
pim::PimSendTensorOp,
|
||||
pim::PimSendTensorBatchOp,
|
||||
pim::PimConcatOp,
|
||||
pim::PimVMMOp,
|
||||
pim::PimTransposeOp,
|
||||
@@ -173,18 +167,6 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
||||
memref::GetGlobalOp>(op);
|
||||
}
|
||||
|
||||
static FailureOr<ShapedType> getStaticByteSizedShapedType(Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return failure();
|
||||
|
||||
return shapedType;
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchOpSemantics(Operation& op,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
@@ -203,73 +185,6 @@ static LogicalResult verifyBatchOpSemantics(Operation& op,
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
if (sendBatchOp.getTargetCoreIds().size() != static_cast<size_t>(sendBatchOp->getParentOfType<pim::PimCoreBatchOp>()
|
||||
.getLaneCount())) {
|
||||
reportFailure([](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("targetCoreIds size must match parent laneCount");
|
||||
});
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
if (receiveBatchOp.getSourceCoreIds().size()
|
||||
!= static_cast<size_t>(receiveBatchOp->getParentOfType<pim::PimCoreBatchOp>().getLaneCount())) {
|
||||
reportFailure([](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("sourceCoreIds size must match parent laneCount");
|
||||
});
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
auto verifyTensorBatchCommunication = [&](Value tensorValue, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty()) {
|
||||
reportFailure([&](Operation* illegalOp) { illegalOp->emitOpError() << kind << " must carry at least one chunk"; });
|
||||
return;
|
||||
}
|
||||
|
||||
auto parentBatchOp = op.getParentOfType<pim::PimCoreBatchOp>();
|
||||
int32_t laneCount = parentBatchOp.getLaneCount();
|
||||
if (laneCount <= 0) {
|
||||
reportFailure([&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kind << " requires a positive parent laneCount";
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (coreIds.size() % static_cast<size_t>(laneCount) != 0) {
|
||||
reportFailure([&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kind << " core id count must be divisible by the parent laneCount";
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
auto shapedType = getStaticByteSizedShapedType(tensorValue.getType());
|
||||
if (failed(shapedType)) {
|
||||
reportFailure([&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kind << " requires a static shaped tensor or memref with byte-sized elements";
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
|
||||
int64_t totalBytes = (*shapedType).getNumElements() * (*shapedType).getElementTypeBitWidth() / 8;
|
||||
if (totalBytes % chunkCount != 0) {
|
||||
reportFailure([&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op))
|
||||
verifyTensorBatchCommunication(sendTensorBatchOp.getInput(),
|
||||
sendTensorBatchOp.getTargetCoreIds(),
|
||||
"send_tensor_batch");
|
||||
else if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op))
|
||||
verifyTensorBatchCommunication(receiveTensorBatchOp.getOutput(),
|
||||
receiveTensorBatchOp.getSourceCoreIds(),
|
||||
"receive_tensor_batch");
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user