cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-25 20:58:51 +02:00
parent bdc4ca33f3
commit 0f240af271
15 changed files with 3 additions and 1182 deletions
-101
View File
@@ -519,61 +519,12 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
/ receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
void PimCodeGen::codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp,
unsigned lane,
const StaticValueKnowledge& knowledge) const {
emitCommunicationOp(
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreIds()[lane], receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
ArrayRef<int32_t> laneCoreIds,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveOp.getOutputBuffer(), knowledge);
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveOp.getOutputBuffer().getType()))
/ laneCoreIds.size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(laneCoreIds))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
auto targetCoreId = indexOf(sendOp.getTargetCoreId(), knowledge);
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
}
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
/ sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
void PimCodeGen::codeGenSendBatchOp(pim::PimSendBatchOp sendOp,
unsigned lane,
const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreIds()[lane], sendOp.getSize());
}
void PimCodeGen::codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
ArrayRef<int32_t> laneCoreIds,
const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendOp.getInput(), knowledge);
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendOp.getInput().getType())) / laneCoreIds.size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(laneCoreIds))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
@@ -902,13 +853,7 @@ enum class CompiledCoreOpKind : uint8_t {
Store,
Lmv,
Receive,
ReceiveBatch,
ReceiveTensor,
ReceiveTensorBatch,
Send,
SendBatch,
SendTensor,
SendTensorBatch,
Concat,
Vmm,
Transpose,
@@ -952,20 +897,8 @@ static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
return CompiledCoreOpKind::Lmv;
if (isa<pim::PimReceiveOp>(op))
return CompiledCoreOpKind::Receive;
if (isa<pim::PimReceiveBatchOp>(op))
return CompiledCoreOpKind::ReceiveBatch;
if (isa<pim::PimReceiveTensorOp>(op))
return CompiledCoreOpKind::ReceiveTensor;
if (isa<pim::PimReceiveTensorBatchOp>(op))
return CompiledCoreOpKind::ReceiveTensorBatch;
if (isa<pim::PimSendOp>(op))
return CompiledCoreOpKind::Send;
if (isa<pim::PimSendBatchOp>(op))
return CompiledCoreOpKind::SendBatch;
if (isa<pim::PimSendTensorOp>(op))
return CompiledCoreOpKind::SendTensor;
if (isa<pim::PimSendTensorBatchOp>(op))
return CompiledCoreOpKind::SendTensorBatch;
if (isa<pim::PimConcatOp>(op))
return CompiledCoreOpKind::Concat;
if (isa<pim::PimVMMOp>(op))
@@ -1108,43 +1041,9 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
case CompiledCoreOpKind::Receive:
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::ReceiveBatch:
if (!batchLane)
return failure();
coreCodeGen.codeGenReceiveBatchOp(cast<pim::PimReceiveBatchOp>(node.op), *batchLane, knowledge);
break;
case CompiledCoreOpKind::ReceiveTensor:
coreCodeGen.codeGenReceiveTensorOp(cast<pim::PimReceiveTensorOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::ReceiveTensorBatch:
if (!batchLane || !batchLaneCount)
return failure();
coreCodeGen.codeGenReceiveTensorBatchOp(cast<pim::PimReceiveTensorBatchOp>(node.op),
getLaneChunkCoreIds(cast<pim::PimReceiveTensorBatchOp>(node.op).getSourceCoreIds(),
*batchLaneCount,
*batchLane),
knowledge);
break;
case CompiledCoreOpKind::Send:
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::SendBatch:
if (!batchLane)
return failure();
coreCodeGen.codeGenSendBatchOp(cast<pim::PimSendBatchOp>(node.op), *batchLane, knowledge);
break;
case CompiledCoreOpKind::SendTensor:
coreCodeGen.codeGenSendTensorOp(cast<pim::PimSendTensorOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::SendTensorBatch:
if (!batchLane || !batchLaneCount)
return failure();
coreCodeGen.codeGenSendTensorBatchOp(cast<pim::PimSendTensorBatchOp>(node.op),
getLaneChunkCoreIds(cast<pim::PimSendTensorBatchOp>(node.op).getTargetCoreIds(),
*batchLaneCount,
*batchLane),
knowledge);
break;
case CompiledCoreOpKind::Concat:
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
break;
-10
View File
@@ -178,17 +178,7 @@ public:
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
llvm::ArrayRef<int32_t> laneCoreIds,
const StaticValueKnowledge& knowledge) const;
void codeGenSendBatchOp(pim::PimSendBatchOp sendOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
llvm::ArrayRef<int32_t> laneCoreIds,
const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
@@ -18,27 +18,6 @@ using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
FailureOr<int32_t> constantValue = getConstantI32Value(value);
if (failed(constantValue))
return failure();
constants.push_back(*constantValue);
}
return constants;
}
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
@@ -62,43 +41,6 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
return coreIds;
}
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
if (failed(targetCoreIds))
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
return success();
}
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
return success();
}
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
@@ -304,51 +246,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
continue;
}
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
if (failed(targetCoreIds))
return sendBatchOp.emitOpError("expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
return failure();
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
return failure();
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
@@ -1,6 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -11,20 +9,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
}
return constants;
}
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
@@ -59,42 +43,6 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
if (failed(targetCoreIds))
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = toPimCoreId(targetCoreId);
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
if (failed(sourceCoreIds))
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = toPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
using OpRewritePattern::OpRewritePattern;
@@ -137,12 +85,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -53,20 +53,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
}
return constants;
}
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
@@ -186,25 +172,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
continue;
}
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
if (receiveTensorOp && !blockArg->use_empty()) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
auto outputType = cast<ShapedType>(blockArg->getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
Value received = PimReceiveTensorOp::create(rewriter,
receiveTensorOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
blockArg->replaceAllUsesWith(received);
markOpToRemove(receiveTensorOp);
}
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
@@ -607,8 +607,6 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
return;
}
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
markOpToRemove(receiveTensorOp);
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
@@ -156,11 +156,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
BuiltinDialect>();
target.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorBatchOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet initialPatterns(ctx);
@@ -234,11 +230,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
BuiltinDialect>();
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorBatchOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
SmallVector<pim::PimCoreOp> coreOps;
@@ -282,9 +274,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
-102
View File
@@ -102,42 +102,6 @@ def PimSendOp : PimOp<"send", []> {
}];
}
def PimSendTensorOp : PimOp<"send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor to target cores";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimSendBatchOp : PimOp<"send_batch", []> {
let summary = "Send a per-lane tensor to target cores from a batched core";
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
DenseI32ArrayAttr:$targetCoreIds
);
let hasCustomAssemblyFormat = 1;
}
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
@@ -162,72 +126,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}];
}
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks from source cores into one tensor";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let summary = "Receive per-lane tensors from source cores into a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasCustomAssemblyFormat = 1;
}
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
-226
View File
@@ -28,34 +28,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
if (index != 0)
printer << ", ";
printer.printOperand(argument);
}
printer << ")";
}
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (parser.parseLParen())
return failure();
if (succeeded(parser.parseOptionalRParen()))
return success();
OpAsmParser::Argument argument;
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
return parser.parseRParen();
}
static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
@@ -98,12 +70,6 @@ static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<i
printCompressedIntegerList(printer, coreIds);
}
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
if (failed(parser.parseOptionalKeyword(keyword)))
return success();
return parseCompressedIntegerList(parser, coreIds);
}
} // namespace
void PimCoreOp::print(OpAsmPrinter& printer) {
@@ -295,198 +261,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis() << " ";
printCompressedValueSequence(printer, getInputs());
-75
View File
@@ -90,56 +90,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
return success();
}
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
return success();
}
static LogicalResult
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
return op->emitError() << kind << " must be nested inside pim.core_batch";
int32_t laneCount = coreBatchOp.getLaneCount();
if (laneCount <= 0)
return op->emitError() << kind << " requires a positive parent laneCount";
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
@@ -177,31 +127,6 @@ LogicalResult PimCoreBatchOp::verify() {
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
LogicalResult PimSendTensorBatchOp::verify() {
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
LogicalResult PimReceiveTensorBatchOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorBatchCommunication(
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
}
LogicalResult PimVMMOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
@@ -157,72 +157,6 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
}
};
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveTensorOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveTensorBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -252,30 +186,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
}
};
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -303,58 +213,6 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
}
};
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct SendTensorBatchOpInterface
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -699,13 +557,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
-105
View File
@@ -194,111 +194,6 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
}];
}
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
@@ -95,13 +95,6 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
return shapedType.getShape();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto batchOp = op->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return failure();
return batchOp.getLaneCount();
}
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
@@ -233,68 +226,6 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
return success();
}
static LogicalResult verifyTensorChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelCount == 0)
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success();
}
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelCount != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
return success();
}
static LogicalResult verifyTensorBatchChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
static Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
@@ -564,52 +495,6 @@ LogicalResult SpatCompute::verify() {
return success();
}
LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getInput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor");
}
LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor_batch");
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor_batch");
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
if (count <= 0)
@@ -79,8 +79,6 @@ struct MergeIrCounts {
uint64_t topLevelComputeBatchCount = 0;
uint64_t scalarChannelSendCount = 0;
uint64_t scalarChannelReceiveCount = 0;
uint64_t tensorChannelSendCount = 0;
uint64_t tensorChannelReceiveCount = 0;
uint64_t wvmmCount = 0;
uint64_t vaddCount = 0;
uint64_t scfForCount = 0;
@@ -95,10 +93,6 @@ MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
++counts.scalarChannelSendCount;
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
++counts.scalarChannelReceiveCount;
else if (isa<spatial::SpatChannelSendTensorOp, spatial::SpatChannelSendTensorBatchOp>(nestedOp))
++counts.tensorChannelSendCount;
else if (isa<spatial::SpatChannelReceiveTensorOp, spatial::SpatChannelReceiveTensorBatchOp>(nestedOp))
++counts.tensorChannelReceiveCount;
else if (isa<spatial::SpatVMMOp>(nestedOp))
++counts.wvmmCount;
else if (isa<spatial::SpatVAddOp>(nestedOp))
@@ -130,9 +124,8 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount
<< " tensor_send=" << counts.tensorChannelSendCount
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount
<< " scf_for=" << counts.scfForCount << "\n";
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
@@ -150,13 +150,7 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
pim::PimMemCopyDevToHostOp,
pim::PimMemCopyOp,
pim::PimReceiveOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorOp,
pim::PimReceiveTensorBatchOp,
pim::PimSendOp,
pim::PimSendBatchOp,
pim::PimSendTensorOp,
pim::PimSendTensorBatchOp,
pim::PimConcatOp,
pim::PimVMMOp,
pim::PimTransposeOp,
@@ -173,18 +167,6 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
memref::GetGlobalOp>(op);
}
static FailureOr<ShapedType> getStaticByteSizedShapedType(Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return failure();
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return failure();
return shapedType;
}
static LogicalResult verifyBatchOpSemantics(Operation& op,
const StaticValueKnowledge& knowledge,
pim::CappedDiagnosticReporter& diagnostics) {
@@ -203,73 +185,6 @@ static LogicalResult verifyBatchOpSemantics(Operation& op,
return success(!hasFailure);
}
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
if (sendBatchOp.getTargetCoreIds().size() != static_cast<size_t>(sendBatchOp->getParentOfType<pim::PimCoreBatchOp>()
.getLaneCount())) {
reportFailure([](Operation* illegalOp) {
illegalOp->emitOpError("targetCoreIds size must match parent laneCount");
});
}
return success(!hasFailure);
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
if (receiveBatchOp.getSourceCoreIds().size()
!= static_cast<size_t>(receiveBatchOp->getParentOfType<pim::PimCoreBatchOp>().getLaneCount())) {
reportFailure([](Operation* illegalOp) {
illegalOp->emitOpError("sourceCoreIds size must match parent laneCount");
});
}
return success(!hasFailure);
}
auto verifyTensorBatchCommunication = [&](Value tensorValue, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty()) {
reportFailure([&](Operation* illegalOp) { illegalOp->emitOpError() << kind << " must carry at least one chunk"; });
return;
}
auto parentBatchOp = op.getParentOfType<pim::PimCoreBatchOp>();
int32_t laneCount = parentBatchOp.getLaneCount();
if (laneCount <= 0) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " requires a positive parent laneCount";
});
return;
}
if (coreIds.size() % static_cast<size_t>(laneCount) != 0) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " core id count must be divisible by the parent laneCount";
});
return;
}
auto shapedType = getStaticByteSizedShapedType(tensorValue.getType());
if (failed(shapedType)) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " requires a static shaped tensor or memref with byte-sized elements";
});
return;
}
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
int64_t totalBytes = (*shapedType).getNumElements() * (*shapedType).getElementTypeBitWidth() / 8;
if (totalBytes % chunkCount != 0) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " tensor byte size must be divisible by the chunk count per lane";
});
}
};
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op))
verifyTensorBatchCommunication(sendTensorBatchOp.getInput(),
sendTensorBatchOp.getTargetCoreIds(),
"send_tensor_batch");
else if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op))
verifyTensorBatchCommunication(receiveTensorBatchOp.getOutput(),
receiveTensorBatchOp.getSourceCoreIds(),
"receive_tensor_batch");
return success(!hasFailure);
}