add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation add spatial compute helper minor refactors
This commit is contained in:
@@ -363,29 +363,6 @@ def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
|
||||
let summary = "Reduce all elements to a single value";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
||||
let summary = "Average all elements into a single value";
|
||||
|
||||
|
||||
@@ -97,8 +97,7 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
|
||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -123,7 +122,7 @@ struct TransposeOpBufferizeInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
||||
struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -160,7 +159,7 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
||||
}
|
||||
};
|
||||
|
||||
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
|
||||
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -186,8 +185,7 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -225,17 +223,56 @@ struct BinaryDstOpBufferizeInterface
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||
const AnalysisState& state,
|
||||
ArrayRef<OpOperand*> opOperands) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto unaryOp = cast<OpTy>(op);
|
||||
|
||||
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
||||
|
||||
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
||||
|
||||
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
||||
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
|
||||
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -16,5 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
|
||||
@@ -105,6 +105,6 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -486,8 +486,6 @@ struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, Spa
|
||||
|
||||
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
||||
|
||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
@@ -496,7 +494,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
|
||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
|
||||
Reference in New Issue
Block a user