add constant folding and verification pass for pim host operations
better validation scripts output big refactors
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
add_subdirectory(PIM)
|
||||
add_subdirectory(Spatial)
|
||||
add_subdirectory(Pim)
|
||||
add_subdirectory(Spatial)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,34 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Dialect/PIM/PimOps.hpp"
|
||||
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace pim {
|
||||
|
||||
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||
|
||||
PimBufferizationPass() = default;
|
||||
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace pim
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -175,7 +175,33 @@ def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Computation
|
||||
// Algebra
|
||||
|
||||
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Matrix transpose
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $data,
|
||||
I64ArrayAttr: $perms,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
@@ -197,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace pim {
|
||||
void PimDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||
|
||||
>();
|
||||
}
|
||||
@@ -45,5 +45,5 @@ POPULATE_DEPENDENCIES(PimVExpOp)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||
@@ -12,7 +12,7 @@
|
||||
#include <string>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"
|
||||
@@ -3,7 +3,6 @@ mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMPimBufferization
|
||||
PimBufferizationPass.hpp
|
||||
PimBufferizationPass.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref);
|
||||
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
@@ -76,6 +76,32 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto transposeOp = cast<PimTransposeOp>(op);
|
||||
|
||||
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
|
||||
if (failed(dataOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -176,6 +202,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -4,7 +4,7 @@
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def memrefCopyToPimMemCopyOp : Pat<
|
||||
@@ -16,4 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat<
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
#endif // PIM_BUFFERIZATION
|
||||
@@ -7,12 +7,37 @@
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
#include "PimBufferizationPass.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||
|
||||
PimBufferizationPass() = default;
|
||||
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimBufferizationPass::runOnOperation() {
|
||||
auto moduleOp = getOperation();
|
||||
|
||||
@@ -68,15 +93,18 @@ void PimBufferizationPass::runOnOperation() {
|
||||
}
|
||||
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
MLIRContext* ctx = funcOp.getContext();
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
||||
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
||||
markWeightAlways(getGlobalOp);
|
||||
markWeightAlways(globalMemrefOp);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -25,7 +25,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
|
||||
LogicalResult SpatImgConcatOp::verify() {
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = GET_IMAGE_WIDTH(imgShape);
|
||||
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
|
||||
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
size_t channelTileRest = img_c % crossbarSize;
|
||||
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
|
||||
return emitError("Invalid input type, must be ShapedType");
|
||||
|
||||
// N == W == H == 1
|
||||
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
|
||||
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
|
||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
||||
|
||||
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
|
||||
size_t inputChannels = getImageChannel(inputShape);
|
||||
|
||||
// Check the number of channels in this tile are correct:
|
||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
||||
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
||||
auto operands = getOperands();
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = GET_IMAGE_WIDTH(imgShape);
|
||||
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
|
||||
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
||||
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user