slightly better bufferization

minor fixes
This commit is contained in:
NiccoloN
2026-05-07 17:53:47 +02:00
parent f2fe147961
commit f6c8cc4aa5
19 changed files with 150 additions and 141 deletions

View File

@@ -4,6 +4,7 @@
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -227,7 +228,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
continue; continue;
} }
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) if (mlir::isa<onnx_mlir::pim::PimEmptyManyOp, mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset}; return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure(); return mlir::failure();

View File

@@ -54,7 +54,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
if (!computeOp || operandIndex >= computeOp.getWeights().size()) if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false; return false;
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex); return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
} }
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {

View File

@@ -97,6 +97,11 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>()) if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult()); gatherMemEntry(allocOp.getResult());
}); });
funcOp.walk([&](pim::PimEmptyManyOp emptyManyOp) {
if (!emptyManyOp->getParentOfType<pim::PimCoreOp>() && !emptyManyOp->getParentOfType<pim::PimCoreBatchOp>())
for (mlir::Value output : emptyManyOp.getOutputs())
gatherMemEntry(output);
});
allocateGatheredMemory(); allocateGatheredMemory();
@@ -106,6 +111,10 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
void PimMemory::allocateCore(Operation* op) { void PimMemory::allocateCore(Operation* op) {
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
op->walk([&](pim::PimEmptyManyOp emptyManyOp) {
for (mlir::Value output : emptyManyOp.getOutputs())
gatherMemEntry(output);
});
allocateGatheredMemory(); allocateGatheredMemory();
} }
@@ -957,6 +966,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op)) else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else if (isa<pim::PimEmptyManyOp>(op))
return success();
else { else {
op.emitError("Unsupported codegen for this operation"); op.emitError("Unsupported codegen for this operation");
op.dump(); op.dump();

View File

@@ -381,7 +381,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back( vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
if (vmmOutputs.empty()) { if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure(); return failure();
@@ -527,7 +527,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc)); &batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
rewriter.setInsertionPointToEnd(body); rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult(); Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value laneResult = vmmResult; Value laneResult = vmmResult;
if (sharedBias) if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();

View File

@@ -95,7 +95,7 @@ bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
return false; return false;
} }
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) { mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1); assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0); mlir::Value result = operation->getResult(0);
auto resultType = result.getType(); auto resultType = result.getType();

View File

@@ -41,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation); mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation); mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
inline mlir::tensor::EmptyOp inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {

View File

@@ -16,13 +16,13 @@ def onnxToPimTranspose : Pat<
>; >;
def spatToPimVMM : Pat< def spatToPimVMM : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), (SpatVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector, (PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimMVM : Pat< def spatToPimMVM : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector), (SpatMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector, (PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;

View File

@@ -252,25 +252,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
rewriter.replaceOp(concatOp, concatenated); rewriter.replaceOp(concatOp, concatenated);
} }
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
wvmmOps.push_back(wvmmOp);
});
for (auto wvmmOp : wvmmOps) {
rewriter.setInsertionPoint(wvmmOp);
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
wvmmOp.getOutput().getType(),
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
wvmmOp.getInput(),
outputBuffer);
}
}
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatMapOp> mapOps; SmallVector<spatial::SpatMapOp> mapOps;
funcOp.walk([&](spatial::SpatMapOp mapOp) { funcOp.walk([&](spatial::SpatMapOp mapOp) {
@@ -736,7 +717,7 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreOp> coreOps; SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) { for (auto coreOp : coreOps) {
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) { if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -745,15 +726,13 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreBatchOp> coreBatchOps; SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) { for (auto coreBatchOp : coreBatchOps) {
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) { if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
} }
} }
lowerRemainingSpatialMathOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx); RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns); populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) { if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {

View File

@@ -96,7 +96,7 @@ def PimEmptyManyOp : PimOp<"empty_many", []> {
let summary = "Create many identical empty tensors"; let summary = "Create many identical empty tensors";
let results = (outs let results = (outs
Variadic<AnyRankedTensor>:$outputs Variadic<PimTensor>:$outputs
); );
let hasVerifier = 1; let hasVerifier = 1;

View File

@@ -79,9 +79,9 @@ LogicalResult PimEmptyManyOp::verify() {
return emitError("must produce at least one output"); return emitError("must produce at least one output");
Type firstType = getOutputs().front().getType(); Type firstType = getOutputs().front().getType();
auto firstTensorType = dyn_cast<RankedTensorType>(firstType); auto firstShapedType = dyn_cast<ShapedType>(firstType);
if (!firstTensorType) if (!firstShapedType || !firstShapedType.hasRank())
return emitError("outputs must all be ranked tensor types"); return emitError("outputs must all be ranked shaped types");
for (Value output : getOutputs().drop_front()) for (Value output : getOutputs().drop_front())
if (output.getType() != firstType) if (output.getType() != firstType)

View File

@@ -34,6 +34,15 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
.getOutput(); .getOutput();
} }
static FailureOr<Value> getBufferOrValue(RewriterBase& rewriter,
Value value,
const BufferizationOptions& options,
BufferizationState& state) {
if (isa<BufferLikeType>(value.getType()))
return value;
return getBuffer(rewriter, value, options, state);
}
struct MemCopyHostToDevOpInterface struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -44,12 +53,12 @@ struct MemCopyHostToDevOpInterface
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget(); auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
auto hostSource = memCopyHostToDevOp.getHostSource(); auto hostSource = memCopyHostToDevOp.getHostSource();
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state); auto deviceTargetOpt = getBufferOrValue(rewriter, deviceTarget, options, state);
if (failed(deviceTargetOpt)) if (failed(deviceTargetOpt))
return failure(); return failure();
auto deviceTargetMemRef = *deviceTargetOpt; auto deviceTargetMemRef = *deviceTargetOpt;
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state); auto hostSourceOpt = getBufferOrValue(rewriter, hostSource, options, state);
if (failed(hostSourceOpt)) if (failed(hostSourceOpt))
return failure(); return failure();
auto hostSourceMemRef = *hostSourceOpt; auto hostSourceMemRef = *hostSourceOpt;
@@ -73,10 +82,10 @@ struct MemCopyHostToDevBatchOpInterface
const BufferizationOptions& options, const BufferizationOptions& options,
BufferizationState& state) const { BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op); auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state); auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
if (failed(deviceTargetOpt)) if (failed(deviceTargetOpt))
return failure(); return failure();
auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state); auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
if (failed(hostSourceOpt)) if (failed(hostSourceOpt))
return failure(); return failure();
@@ -101,13 +110,13 @@ struct MemCopyDevToHostOpInterface
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op); auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
auto hostTarget = memCopyDevToHostOp.getHostTarget(); auto hostTarget = memCopyDevToHostOp.getHostTarget();
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state); auto hostTargetOpt = getBufferOrValue(rewriter, hostTarget, options, state);
if (failed(hostTargetOpt)) if (failed(hostTargetOpt))
return failure(); return failure();
auto hostTargetMemRef = *hostTargetOpt; auto hostTargetMemRef = *hostTargetOpt;
auto deviceSource = memCopyDevToHostOp.getDeviceSource(); auto deviceSource = memCopyDevToHostOp.getDeviceSource();
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state); auto deviceSourceOpt = getBufferOrValue(rewriter, deviceSource, options, state);
if (failed(deviceSourceOpt)) if (failed(deviceSourceOpt))
return failure(); return failure();
auto deviceSourceMemRef = *deviceSourceOpt; auto deviceSourceMemRef = *deviceSourceOpt;
@@ -135,7 +144,7 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
BufferizationState& state) const { BufferizationState& state) const {
auto receiveOp = cast<PimReceiveOp>(op); auto receiveOp = cast<PimReceiveOp>(op);
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -159,7 +168,7 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<Receive
const BufferizationOptions& options, const BufferizationOptions& options,
BufferizationState& state) const { BufferizationState& state) const {
auto receiveOp = cast<PimReceiveBatchOp>(op); auto receiveOp = cast<PimReceiveBatchOp>(op);
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -185,13 +194,11 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveM
auto receiveOp = cast<PimReceiveManyOp>(op); auto receiveOp = cast<PimReceiveManyOp>(op);
SmallVector<Value> outputBuffers; SmallVector<Value> outputBuffers;
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
SmallVector<Value> tensorResults;
outputBuffers.reserve(receiveOp.getOutputBuffers().size()); outputBuffers.reserve(receiveOp.getOutputBuffers().size());
resultTypes.reserve(receiveOp.getOutputBuffers().size()); resultTypes.reserve(receiveOp.getOutputBuffers().size());
tensorResults.reserve(receiveOp.getOutputBuffers().size());
for (Value outputBuffer : receiveOp.getOutputBuffers()) { for (Value outputBuffer : receiveOp.getOutputBuffers()) {
auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
outputBuffers.push_back(*outputBufferOpt); outputBuffers.push_back(*outputBufferOpt);
@@ -200,15 +207,7 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveM
auto newOp = PimReceiveManyOp::create( auto newOp = PimReceiveManyOp::create(
rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr()); rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr());
rewriter.replaceOp(receiveOp, newOp.getOutputs());
for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) {
auto tensorType = cast<RankedTensorType>(tensorResult.getType());
auto toTensor =
bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr());
tensorResults.push_back(toTensor.getResult());
}
rewriter.replaceOp(receiveOp, tensorResults);
return success(); return success();
} }
}; };
@@ -226,13 +225,11 @@ struct ReceiveManyBatchOpInterface
auto receiveOp = cast<PimReceiveManyBatchOp>(op); auto receiveOp = cast<PimReceiveManyBatchOp>(op);
SmallVector<Value> outputBuffers; SmallVector<Value> outputBuffers;
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
SmallVector<Value> tensorResults;
outputBuffers.reserve(receiveOp.getOutputBuffers().size()); outputBuffers.reserve(receiveOp.getOutputBuffers().size());
resultTypes.reserve(receiveOp.getOutputBuffers().size()); resultTypes.reserve(receiveOp.getOutputBuffers().size());
tensorResults.reserve(receiveOp.getOutputBuffers().size());
for (Value outputBuffer : receiveOp.getOutputBuffers()) { for (Value outputBuffer : receiveOp.getOutputBuffers()) {
auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
outputBuffers.push_back(*outputBufferOpt); outputBuffers.push_back(*outputBufferOpt);
@@ -244,15 +241,7 @@ struct ReceiveManyBatchOpInterface
TypeRange(resultTypes), TypeRange(resultTypes),
ValueRange(outputBuffers), ValueRange(outputBuffers),
receiveOp.getSourceCoreIdsAttr()); receiveOp.getSourceCoreIdsAttr());
rewriter.replaceOp(receiveOp, newOp.getOutputs());
for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) {
auto tensorType = cast<RankedTensorType>(tensorResult.getType());
auto toTensor =
bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr());
tensorResults.push_back(toTensor.getResult());
}
rewriter.replaceOp(receiveOp, tensorResults);
return success(); return success();
} }
}; };
@@ -267,7 +256,7 @@ struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractR
const BufferizationOptions& options, const BufferizationOptions& options,
BufferizationState& state) const { BufferizationState& state) const {
auto extractRowsOp = cast<PimExtractRowsOp>(op); auto extractRowsOp = cast<PimExtractRowsOp>(op);
auto inputOpt = getBuffer(rewriter, extractRowsOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, extractRowsOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
@@ -277,7 +266,7 @@ struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractR
resultTypes.reserve(extractRowsOp.getOutputBuffers().size()); resultTypes.reserve(extractRowsOp.getOutputBuffers().size());
for (Value outputBuffer : extractRowsOp.getOutputBuffers()) { for (Value outputBuffer : extractRowsOp.getOutputBuffers()) {
auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
outputBuffers.push_back(*outputBufferOpt); outputBuffers.push_back(*outputBufferOpt);
@@ -307,13 +296,13 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
SmallVector<Value> inputs; SmallVector<Value> inputs;
inputs.reserve(concatOp.getInputs().size()); inputs.reserve(concatOp.getInputs().size());
for (Value input : concatOp.getInputs()) { for (Value input : concatOp.getInputs()) {
auto inputOpt = getBuffer(rewriter, input, options, state); auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter)); inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
} }
auto outputBufferOpt = getBuffer(rewriter, concatOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -323,6 +312,31 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
} }
}; };
struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOpInterface, PimEmptyManyOp> {
bool bufferizesToAllocation(Operation* op, Value value) const { return true; }
bool resultBufferizesToMemoryWrite(Operation* op, OpResult opResult, const AnalysisState& state) const {
return false;
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto emptyManyOp = cast<PimEmptyManyOp>(op);
SmallVector<Type> resultTypes;
resultTypes.reserve(emptyManyOp.getOutputs().size());
for (Value output : emptyManyOp.getOutputs()) {
auto shapedType = cast<ShapedType>(output.getType());
resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
}
replaceOpWithNewBufferizedOp<PimEmptyManyOp>(rewriter, emptyManyOp, TypeRange(resultTypes));
return success();
}
};
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> { struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -375,7 +389,7 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
for (Value input : mapOp.getInputs()) { for (Value input : mapOp.getInputs()) {
if (isa<TensorType>(input.getType())) { if (isa<TensorType>(input.getType())) {
auto inputOpt = getBuffer(rewriter, input, options, state); auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
inputs.push_back(*inputOpt); inputs.push_back(*inputOpt);
@@ -453,6 +467,16 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
BufferizationState& state) const { BufferizationState& state) const {
auto coreBatchOp = cast<PimCoreBatchOp>(op); auto coreBatchOp = cast<PimCoreBatchOp>(op);
bool alreadyBufferized = llvm::all_of(coreBatchOp.getWeights(), [](Value weight) {
return isa<BufferLikeType>(weight.getType());
}) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) {
return isa<BufferLikeType>(input.getType());
}) && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
return isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized)
return success();
SmallVector<Value> weights; SmallVector<Value> weights;
SmallVector<Value> inputs; SmallVector<Value> inputs;
weights.reserve(coreBatchOp.getWeights().size()); weights.reserve(coreBatchOp.getWeights().size());
@@ -460,7 +484,7 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
for (Value weight : coreBatchOp.getWeights()) { for (Value weight : coreBatchOp.getWeights()) {
if (isa<TensorType>(weight.getType())) { if (isa<TensorType>(weight.getType())) {
auto weightOpt = getBuffer(rewriter, weight, options, state); auto weightOpt = getBufferOrValue(rewriter, weight, options, state);
if (failed(weightOpt)) if (failed(weightOpt))
return failure(); return failure();
weights.push_back(*weightOpt); weights.push_back(*weightOpt);
@@ -472,7 +496,7 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
for (Value input : coreBatchOp.getInputs()) { for (Value input : coreBatchOp.getInputs()) {
if (isa<TensorType>(input.getType())) { if (isa<TensorType>(input.getType())) {
auto inputOpt = getBuffer(rewriter, input, options, state); auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
inputs.push_back(*inputOpt); inputs.push_back(*inputOpt);
@@ -510,11 +534,11 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
BufferizationState& state) const { BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op); auto transposeOp = cast<PimTransposeOp>(op);
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, transposeOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, transposeOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -547,11 +571,11 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
BufferizationState& state) const { BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op); auto vmmOp = cast<PimVMMOp>(op);
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, vmmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -574,11 +598,11 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
BufferizationState& state) const { BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op); auto mvmOp = cast<PimMVMOp>(op);
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -608,15 +632,15 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
BufferizationState& state) const { BufferizationState& state) const {
auto binaryOp = cast<OpTy>(op); auto binaryOp = cast<OpTy>(op);
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state); auto lhsOpt = getBufferOrValue(rewriter, binaryOp.getLhs(), options, state);
if (failed(lhsOpt)) if (failed(lhsOpt))
return failure(); return failure();
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state); auto rhsOpt = getBufferOrValue(rewriter, binaryOp.getRhs(), options, state);
if (failed(rhsOpt)) if (failed(rhsOpt))
return failure(); return failure();
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, binaryOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -647,11 +671,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
BufferizationState& state) const { BufferizationState& state) const {
auto unaryOp = cast<OpTy>(op); auto unaryOp = cast<OpTy>(op);
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, unaryOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, unaryOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
@@ -664,6 +688,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) { void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimEmptyManyOp::attachInterface<EmptyManyOpInterface>(*ctx);
PimMapOp::attachInterface<MapOpInterface>(*ctx); PimMapOp::attachInterface<MapOpInterface>(*ctx);
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx); PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx); PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);

View File

@@ -47,37 +47,26 @@ private:
void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation(); auto moduleOp = getOperation();
{
SmallVector<pim::PimEmptyManyOp> emptyManyOps;
moduleOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { emptyManyOps.push_back(emptyManyOp); });
IRRewriter rewriter(moduleOp.getContext());
for (auto emptyManyOp : emptyManyOps) {
SmallVector<Value> replacementValues;
replacementValues.reserve(emptyManyOp.getOutputs().size());
rewriter.setInsertionPoint(emptyManyOp);
for (Value output : emptyManyOp.getOutputs()) {
auto outputType = cast<RankedTensorType>(output.getType());
replacementValues.push_back(
tensor::EmptyOp::create(rewriter, emptyManyOp.getLoc(), outputType.getShape(), outputType.getElementType()));
}
rewriter.replaceOp(emptyManyOp, replacementValues);
}
}
// Refactor this into a function // Refactor this into a function
{ {
auto funcOp = getPimEntryFunc(moduleOp); auto funcOp = *getPimEntryFunc(moduleOp);
auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>()); SmallVector<Operation*> coreOps;
funcOp->walk<WalkOrder::PreOrder>([&](Operation* op) {
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
coreOps.push_back(op);
});
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
// failableParallelForEach will run the lambda in parallel and stop if any thread fails // failableParallelForEach will run the lambda in parallel and stop if any thread fails
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) { LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](Operation* coreOp) {
// Again, allocate state LOCALLY per thread/function // Again, allocate state LOCALLY per thread/function
bufferization::OneShotBufferizationOptions options; bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true; options.allowUnknownOps = true;
if (isa<pim::PimCoreBatchOp>(coreOp))
options.opFilter.denyOperation([coreOp](Operation* op) { return op == coreOp; });
bufferization::BufferizationState state; bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) { if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
coreOp.emitError("Failed to bufferize PIM and Spatial ops"); coreOp->emitError("Failed to bufferize PIM and Spatial ops");
return failure(); return failure();
} }
return success(); return success();
@@ -89,13 +78,16 @@ void PimBufferizationPass::runOnOperation() {
} }
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) { funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp())) if (llvm::isa_and_present<pim::PimCoreOp, pim::PimCoreBatchOp>(toTensorOp->getParentOp()))
toTensorOp->setAttr("restrict", UnitAttr::get(ctx)); toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
}); });
// One-Shot-Bufferization // One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options; bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true; options.allowUnknownOps = true;
options.opFilter.denyOperation([](Operation* op) {
return op->getParentOfType<pim::PimCoreOp>() || op->getParentOfType<pim::PimCoreBatchOp>();
});
bufferization::BufferizationState state; bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {

View File

@@ -253,7 +253,7 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
// Math // Math
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def SpatWeightedVMMOp : SpatOp<"wvmm", []> { def SpatVMMOp : SpatOp<"wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation"; let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins let arguments = (ins
@@ -272,7 +272,7 @@ def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
}]; }];
} }
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> { def SpatMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation"; let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins let arguments = (ins

View File

@@ -20,7 +20,7 @@ namespace spatial {
namespace { namespace {
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape, ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape, ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) { ArrayRef<int64_t>& outputShape) {
@@ -45,7 +45,7 @@ inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
return success(); return success();
} }
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape, ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape, ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) { ArrayRef<int64_t>& outputShape) {
@@ -177,10 +177,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
} }
for (auto& bodyOp : block) { for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp)) if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane) if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp)) if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane) if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
} }
@@ -189,10 +189,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
} // namespace } // namespace
LogicalResult SpatWeightedMVMOp::verify() { LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op"); return emitError("SpatMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt; auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape(); auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape(); auto outputShape = getOutput().getType().getShape();
@@ -204,10 +204,10 @@ LogicalResult SpatWeightedMVMOp::verify() {
return emitError("matrix rank must be 2 or 4"); return emitError("matrix rank must be 2 or 4");
} }
LogicalResult SpatWeightedVMMOp::verify() { LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op"); return emitError("SpatVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt; auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape(); auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape(); auto outputShape = getOutput().getType().getShape();

View File

@@ -133,7 +133,7 @@ CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
CrossbarUsage crossbarUsage = 0; CrossbarUsage crossbarUsage = 0;
for (auto& block : body) for (auto& block : body)
for (auto& op : block) for (auto& op : block)
if (isa<SpatWeightedVMMOp>(op)) if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1)); crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage; return crossbarUsage;
} }

View File

@@ -105,7 +105,7 @@ inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute
CrossbarUsage crossbarUsage = 0; CrossbarUsage crossbarUsage = 0;
for (auto& region : spatCompute.getBody()) for (auto& region : spatCompute.getBody())
for (auto& inst : region) for (auto& inst : region)
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst)) if (llvm::isa<onnx_mlir::spatial::SpatVMMOp>(inst))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1)); crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage; return crossbarUsage;
} }

View File

@@ -838,9 +838,9 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
for (auto& op : child.getBody().front()) { for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper); auto newInst = rewriter.clone(op, mapper);
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
remapWeightIndex(weightedMvmOp); remapWeightIndex(weightedMvmOp);
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
remapWeightIndex(weightedVmmOp); remapWeightIndex(weightedVmmOp);
} }
@@ -884,9 +884,9 @@ void emitMotifProfile(func::FuncOp funcOp) {
ComputeMotifInfo& info = computeInfos[index]; ComputeMotifInfo& info = computeInfos[index];
for (Operation& op : compute.getBody().front()) { for (Operation& op : compute.getBody().front()) {
info.instructionCount++; info.instructionCount++;
if (isa<spatial::SpatWeightedMVMOp>(&op)) if (isa<spatial::SpatMVMOp>(&op))
info.weightedMvmCount++; info.weightedMvmCount++;
if (isa<spatial::SpatWeightedVMMOp>(&op)) if (isa<spatial::SpatVMMOp>(&op))
info.weightedVmmCount++; info.weightedVmmCount++;
} }
if (info.weightedVmmCount > 0) { if (info.weightedVmmCount > 0) {
@@ -1617,13 +1617,13 @@ public:
} }
Operation* clonedOp = cpuRewriter.clone(op, mapper); Operation* clonedOp = cpuRewriter.clone(op, mapper);
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) { if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp); auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight)); newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
} }
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) { if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp); auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()]; Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
} }
@@ -1643,22 +1643,22 @@ public:
} }
Operation* clonedOp = cpuRewriter.clone(op, mapper); Operation* clonedOp = cpuRewriter.clone(op, mapper);
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) { if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
if (oldWeightedMvmOp.getWeightIndex() != 0) { if (oldWeightedMvmOp.getWeightIndex() != 0) {
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
signalPassFailure(); signalPassFailure();
return; return;
} }
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp); auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
} }
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) { if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
if (oldWeightedVmmOp.getWeightIndex() != 0) { if (oldWeightedVmmOp.getWeightIndex() != 0) {
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
signalPassFailure(); signalPassFailure();
return; return;
} }
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp); auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
} }
} }

View File

@@ -55,7 +55,7 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
} }
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) { static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
RegularChunk chunk; RegularChunk chunk;
chunk.startOp = startOp.getOperation(); chunk.startOp = startOp.getOperation();
chunk.input = startOp.getInput(); chunk.input = startOp.getInput();
@@ -376,7 +376,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
auto compactInBlock = [&](Block& block) { auto compactInBlock = [&](Block& block) {
for (auto it = block.begin(); it != block.end();) { for (auto it = block.begin(); it != block.end();) {
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it); auto startOp = dyn_cast<spatial::SpatVMMOp>(&*it);
if (!startOp) { if (!startOp) {
++it; ++it;
continue; continue;
@@ -391,7 +391,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
SmallVector<RegularChunk> run {*anchorChunk}; SmallVector<RegularChunk> run {*anchorChunk};
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size())); auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
while (runIt != block.end()) { while (runIt != block.end()) {
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt); auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
if (!candidateStart) if (!candidateStart)
break; break;
@@ -425,7 +425,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) { for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front(); Block& block = compute.getBody().front();
for (auto it = block.begin(); it != block.end();) { for (auto it = block.begin(); it != block.end();) {
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it); auto wvmmOp = dyn_cast<spatial::SpatVMMOp>(&*it);
if (!wvmmOp) { if (!wvmmOp) {
++it; ++it;
continue; continue;
@@ -440,11 +440,11 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
continue; continue;
} }
SmallVector<spatial::SpatWeightedVMMOp> run; SmallVector<spatial::SpatVMMOp> run;
auto runIt = it; auto runIt = it;
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber()); int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
while (runIt != block.end()) { while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt); auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex() if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp || current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|| current.getInput().getType() != wvmmOp.getInput().getType() || current.getInput().getType() != wvmmOp.getInput().getType()
@@ -545,7 +545,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
extractOffsets, extractOffsets,
extractSizes, extractSizes,
extractStrides); extractStrides);
auto loopWvmm = spatial::SpatWeightedVMMOp::create( auto loopWvmm = spatial::SpatVMMOp::create(
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};

View File

@@ -18,6 +18,7 @@ namespace {
static bool isAddressOnlyHostOp(Operation* op) { static bool isAddressOnlyHostOp(Operation* op) {
return isa<arith::ConstantOp, return isa<arith::ConstantOp,
pim::PimEmptyManyOp,
memref::AllocOp, memref::AllocOp,
memref::GetGlobalOp, memref::GetGlobalOp,
memref::SubViewOp, memref::SubViewOp,
@@ -36,7 +37,7 @@ static bool isBaseAddressableValue(Value value) {
Operation* defOp = value.getDefiningOp(); Operation* defOp = value.getDefiningOp();
if (!defOp) if (!defOp)
return false; return false;
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp)) if (isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(defOp))
return true; return true;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; } if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; } if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
@@ -51,7 +52,7 @@ static bool isCodegenAddressableValue(Value value) {
if (failed(resolvedAddress)) if (failed(resolvedAddress))
return false; return false;
return isa<BlockArgument>(resolvedAddress->base) return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp()); || isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
} }
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
@@ -184,7 +185,7 @@ private:
continue; continue;
} }
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) { if (!isa<pim::PimEmptyManyOp, memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd"; << " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true; hasFailure = true;