cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -18,27 +18,6 @@ using namespace onnx_mlir::pim;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
FailureOr<int32_t> constantValue = getConstantI32Value(value);
|
||||
if (failed(constantValue))
|
||||
return failure();
|
||||
constants.push_back(*constantValue);
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
@@ -62,43 +41,6 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
|
||||
pim::PimSendTensorBatchOp::create(rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
@@ -304,51 +246,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
loc,
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveBatchOp.getOutput(), received);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -11,20 +9,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -59,42 +43,6 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = toPimCoreId(targetCoreId);
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = toPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -137,12 +85,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
||||
} // namespace
|
||||
|
||||
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<ChannelSendLowering,
|
||||
ChannelReceiveLowering,
|
||||
ChannelSendTensorLowering,
|
||||
ChannelReceiveTensorLowering,
|
||||
ExtractRowsLowering,
|
||||
ConcatLowering>(patterns.getContext());
|
||||
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -53,20 +53,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
||||
}
|
||||
}
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
@@ -186,25 +172,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||
if (receiveTensorOp && !blockArg->use_empty()) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
|
||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||
|
||||
@@ -607,8 +607,6 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||
markOpToRemove(receiveTensorOp);
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
|
||||
@@ -156,11 +156,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
BuiltinDialect>();
|
||||
target.addLegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelReceiveTensorBatchOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet initialPatterns(ctx);
|
||||
@@ -234,11 +230,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
BuiltinDialect>();
|
||||
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelReceiveTensorBatchOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
@@ -282,9 +274,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
communicationTarget.addLegalOp<ModuleOp>();
|
||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
|
||||
Reference in New Issue
Block a user