reimplement pool lowering
add pool validation align PIM ops/codegen/parser with the ISA move constant materialization to MLIR rename the PIM verification/materialization passes better folded-constant handling
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
@@ -36,6 +37,25 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
||||
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||
}
|
||||
|
||||
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return pim::PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousBuffer.getType(),
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
|
||||
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
|
||||
@@ -167,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
auto memref = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(memref))
|
||||
return failure();
|
||||
memrefOperands.push_back(*memref);
|
||||
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||
}
|
||||
|
||||
// TODO: Support addiction with more than 2 operands
|
||||
@@ -460,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
|
||||
};
|
||||
|
||||
struct VAddOpInterfaceFromTemplate
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
||||
|
||||
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||
|
||||
@@ -468,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
|
||||
|
||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||
|
||||
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
// Create a new bufferizable op interface for the apply filters operation.
|
||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
||||
@@ -557,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
||||
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
|
||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
@@ -569,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
|
||||
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||
|
||||
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
|
||||
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
||||
|
||||
struct ONNXSigmoidInterface
|
||||
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
||||
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
||||
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
|
||||
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
|
||||
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user