reimplement pool lowering
add pool validation align PIM ops/codegen/parser with the ISA move constant materialization to MLIR rename the PIM verification/materialization passes better folded-constant handling
This commit is contained in:
@@ -530,6 +530,7 @@ where
|
||||
let r2_val = r2;
|
||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||
let rd_val = core.register(rd);
|
||||
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||
let load1 = loads[0];
|
||||
|
||||
@@ -224,7 +224,21 @@ fn json_to_vvsub(
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
todo!("Not present in the compiler");
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vvsub", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let rs2 = json_i64!(json, "rs2") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_r2(rs2)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -256,7 +270,21 @@ fn json_to_vvdmul(
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
todo!("Not present in the compiler");
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vvdmul", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let rs2 = json_i64!(json, "rs2") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_r2(rs2)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -306,7 +334,21 @@ fn json_to_vavg(
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
todo!("Not present in the compiler");
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vavg", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let rs2 = json_i64!(json, "rs2") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_r2(rs2)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -358,7 +400,7 @@ fn json_to_vsigm(
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vsigmoid", json_str!(json, "op"));
|
||||
assert_eq!("vsigm", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
|
||||
@@ -237,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
|
||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vaddOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vaddOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
|
||||
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
|
||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvaddOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvaddOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvadd";
|
||||
@@ -252,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = totalBytes;
|
||||
json["len"] = getValueSizeInBytes(vvaddOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vmaxOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vmaxOp.getB());
|
||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvsubOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvsubOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvsub";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvsubOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvmulOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvmulOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmulOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -268,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvdmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vavgOp.getA());
|
||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vavg";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vavgOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
@@ -281,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vreluOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vtanhOp.getA());
|
||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vtanh";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vtanhOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vsigmOp.getA());
|
||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vsigm";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vsigmOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
@@ -338,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
|
||||
vaddJson["rs1"] = 1;
|
||||
vaddJson["rs2"] = 2;
|
||||
vaddJson["offset"] = createEmptyOffset();
|
||||
vaddJson["len"] = 32 * outChannels;
|
||||
emitInstruction(std::move(vaddJson));
|
||||
}
|
||||
}
|
||||
@@ -479,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
|
||||
coreCodeGen.codeGenVAddOp(vaddOp);
|
||||
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
|
||||
coreCodeGen.codeGenVMaxOp(vmaxOp);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
coreCodeGen.codeGenVVAddOp(vvaddOp);
|
||||
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
||||
coreCodeGen.codeGenVVSubOp(vvsubOp);
|
||||
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
||||
coreCodeGen.codeGenVVMulOp(vvmulOp);
|
||||
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
||||
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
|
||||
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
||||
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
|
||||
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
||||
coreCodeGen.codeGenVAvgOp(vavgOp);
|
||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||
coreCodeGen.codeGenVReluOp(vreluOp);
|
||||
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
|
||||
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
||||
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||
else if (isa<pim::PimSumOp>(op)) {
|
||||
// TODO: Implement somehow?
|
||||
op.emitWarning("Operation is not yet supported in code generation");
|
||||
continue;
|
||||
|
||||
@@ -90,9 +90,15 @@ public:
|
||||
template <typename MVMTy>
|
||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
||||
|
||||
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
|
||||
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
|
||||
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
|
||||
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
|
||||
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
|
||||
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
|
||||
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
|
||||
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
|
||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||
};
|
||||
|
||||
@@ -47,8 +47,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createPimHostVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim host verified"));
|
||||
pm.addPass(createPimMaterializeConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||
|
||||
@@ -6,7 +6,7 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
Patterns/NN/Pooling.cpp
|
||||
Patterns/NN/Pool.cpp
|
||||
Patterns/NN/ReduceMean.cpp
|
||||
Patterns/Tensor/Concat.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
|
||||
@@ -93,7 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
patterns.add<removeLRNPattern>(ctx);
|
||||
|
||||
populateConvOpPatterns(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populatePoolTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
populateReshapeConversionPattern(patterns, ctx);
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIR
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
|
||||
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
@@ -0,0 +1,265 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||
}
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(tileType.getRank());
|
||||
for (int64_t dimSize : tileType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename ReduceOp>
|
||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||
|
||||
Value reduced = windowValues.front();
|
||||
for (Value value : windowValues.drop_front())
|
||||
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static Value
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||
if (divisor == 1)
|
||||
return reducedWindow;
|
||||
|
||||
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
||||
double scale = 1.0 / static_cast<double>(divisor);
|
||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
struct PoolToSpatialCompute;
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
using OpConversionPattern<PoolOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location loc = poolOp.getLoc();
|
||||
Value x = adaptor.getX();
|
||||
|
||||
auto xType = dyn_cast<RankedTensorType>(x.getType());
|
||||
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
|
||||
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
|
||||
if (xType.getRank() != 4 || outType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
|
||||
|
||||
ArrayAttr kernelAttr = poolOp.getKernelShape();
|
||||
if (!kernelAttr || kernelAttr.size() != 2)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t channels = xType.getDimSize(1);
|
||||
const int64_t inputHeight = xType.getDimSize(2);
|
||||
const int64_t inputWidth = xType.getDimSize(3);
|
||||
const int64_t outputHeight = outType.getDimSize(2);
|
||||
const int64_t outputWidth = outType.getDimSize(3);
|
||||
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||
|
||||
int64_t padTop = 0;
|
||||
int64_t padLeft = 0;
|
||||
int64_t padBottom = 0;
|
||||
int64_t padRight = 0;
|
||||
|
||||
if (auto padsAttr = poolOp.getPads()) {
|
||||
if (padsAttr->size() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||
padTop = getI64(*padsAttr, 0);
|
||||
padLeft = getI64(*padsAttr, 1);
|
||||
padBottom = getI64(*padsAttr, 2);
|
||||
padRight = getI64(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
StringRef autoPad = poolOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
|
||||
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padTop = totalPadH / 2;
|
||||
padBottom = totalPadH - padTop;
|
||||
padLeft = totalPadW / 2;
|
||||
padRight = totalPadW - padLeft;
|
||||
}
|
||||
else {
|
||||
padBottom = totalPadH / 2;
|
||||
padTop = totalPadH - padBottom;
|
||||
padRight = totalPadW / 2;
|
||||
padLeft = totalPadW - padRight;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
|
||||
}
|
||||
}
|
||||
|
||||
(void) padBottom;
|
||||
(void) padRight;
|
||||
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
computeBlock->addArgument(xType, loc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
Value input = computeBlock->getArgument(0);
|
||||
SmallVector<Value> batchResults;
|
||||
batchResults.reserve(batchSize);
|
||||
|
||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||
SmallVector<Value> rows;
|
||||
rows.reserve(outputHeight);
|
||||
|
||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||
SmallVector<Value> rowPixels;
|
||||
rowPixels.reserve(outputWidth);
|
||||
|
||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||
SmallVector<Value> outputChannelTiles;
|
||||
outputChannelTiles.reserve(channelTileCount);
|
||||
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> windowValues;
|
||||
windowValues.reserve(kernelHeight * kernelWidth);
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||
if (inH < 0 || inH >= inputHeight)
|
||||
continue;
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||
if (inW < 0 || inW >= inputWidth)
|
||||
continue;
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
rewriter.getIndexAttr(inH),
|
||||
rewriter.getIndexAttr(inW)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
windowValues.push_back(windowValue);
|
||||
}
|
||||
}
|
||||
|
||||
if (windowValues.empty())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||
|
||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
const int64_t divisor =
|
||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||
}
|
||||
|
||||
outputChannelTiles.push_back(reducedWindow);
|
||||
}
|
||||
|
||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||
}
|
||||
|
||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||
}
|
||||
|
||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||
}
|
||||
|
||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||
|
||||
rewriter.replaceOp(poolOp, computeOp.getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
|
||||
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
|
||||
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
||||
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
|
||||
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,427 +0,0 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess) {
|
||||
// Simple case: if we have only one input, just return it
|
||||
if (valuesToReduce.size() == 1)
|
||||
return valuesToReduce[0];
|
||||
|
||||
if (preprocess) {
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = preprocess(valToReduce);
|
||||
}
|
||||
}
|
||||
|
||||
// It is possible that `valuesToReduce` contains two entries for the same
|
||||
// computeOp. In this case, we need to apply the reduction within-computef
|
||||
|
||||
// Keep a map between a computeOp and the last Value for this reduction
|
||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
|
||||
// if (valToReduce.getDefiningOp()) {
|
||||
// // If the value is defined by an operation, we take the parent
|
||||
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
|
||||
// } else {
|
||||
// // Otherwise it is a block argument,
|
||||
// computeOp->getBlock()->getParentOp();
|
||||
// }
|
||||
|
||||
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
|
||||
|
||||
auto it = lastValueForCompute.find(computeOp);
|
||||
|
||||
if (it != lastValueForCompute.end()) {
|
||||
// If we have already seen this computeOp, apply the reduction
|
||||
// within-compute
|
||||
Value lastWithinComputeValue = it->second;
|
||||
|
||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
||||
else
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = reduce(lastWithinComputeValue, valToReduce);
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
|
||||
// Now, reconstruct from the map the valuesToReduce list
|
||||
valuesToReduce.clear();
|
||||
valuesToReduce.reserve(lastValueForCompute.size());
|
||||
for (auto& entry : lastValueForCompute)
|
||||
valuesToReduce.push_back(entry.second);
|
||||
|
||||
Location loc = valuesToReduce[0].getLoc();
|
||||
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
|
||||
|
||||
// Recursive algorithm to reduce the inputs to a single one:
|
||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
||||
// the valuesToReduce list which becomes half the size.
|
||||
// - Repeat until there is only one input left.
|
||||
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
|
||||
while (valuesToReduceRef.size() > 1) {
|
||||
SmallVector<Value> nextValuesToReduce;
|
||||
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
|
||||
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
|
||||
auto firstValue = valuesToReduceRef[i];
|
||||
auto secondValue = valuesToReduceRef[i + 1];
|
||||
|
||||
auto firstCompute = firstValue.getParentBlock()->getParentOp();
|
||||
auto secondCompute = secondValue.getParentBlock()->getParentOp();
|
||||
|
||||
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
|
||||
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
|
||||
|
||||
if (secondCompute->isBeforeInBlock(firstCompute)) {
|
||||
std::swap(firstValue, secondValue);
|
||||
std::swap(firstCompute, secondCompute);
|
||||
}
|
||||
|
||||
// 1. Add a channel before the first computeOp
|
||||
rewriter.setInsertionPoint(firstCompute);
|
||||
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||
|
||||
// 2. Add a sendOp after the first value
|
||||
rewriter.setInsertionPointAfterValue(firstValue);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
|
||||
|
||||
// 3. Add a receiveOp after the second value
|
||||
rewriter.setInsertionPointAfterValue(secondValue);
|
||||
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
|
||||
|
||||
// 4. Apply reduction between second value and received value
|
||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
||||
Value reduced = reduce(receivedValue, secondValue);
|
||||
|
||||
nextValuesToReduce.push_back(reduced);
|
||||
}
|
||||
|
||||
// If we have an odd number of inputs, we need to add the last one to the
|
||||
// newInputs list.
|
||||
if (valuesToReduceRef.size() % 2 == 1)
|
||||
nextValuesToReduce.push_back(valuesToReduceRef.back());
|
||||
|
||||
// Replace the inputOps list with the new one.
|
||||
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
||||
}
|
||||
|
||||
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
|
||||
|
||||
auto finalValue = valuesToReduceRef[0];
|
||||
|
||||
if (postprocess) {
|
||||
rewriter.setInsertionPointAfterValue(finalValue);
|
||||
finalValue = postprocess(finalValue);
|
||||
}
|
||||
|
||||
return finalValue;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
bool hasPostProcessPoolingWindow() {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
||||
loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
PoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
size_t input_h = getImageHeight(xShape);
|
||||
size_t input_w = getImageWidth(xShape);
|
||||
size_t output_h = getImageHeight(yShape);
|
||||
size_t output_w = getImageWidth(yShape);
|
||||
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
|
||||
|
||||
// 1: Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
||||
// Suppose that the input tensor is produced by concatenating the results of
|
||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt =
|
||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
||||
|
||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
||||
// can do better.
|
||||
// If some input tiles come from the func.func operands, load
|
||||
// them into a computeOp and yield them
|
||||
for (size_t t = 0; t < channelTileCount; t++) {
|
||||
for (size_t x = 0; x < input_w; x++) {
|
||||
for (size_t y = 0; y < input_h; y++) {
|
||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Location tileLoc = extractSliceOp.getLoc();
|
||||
|
||||
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
||||
tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
|
||||
Block* tempComputeOpBlock = new Block();
|
||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||
|
||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2: Tile the output tensor
|
||||
// Output tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: outputTiles[channelTile][x][y]
|
||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
||||
|
||||
// List of values to pool for each output pixel
|
||||
SmallVector<Value> valuesToPool;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
|
||||
// Each output pixel tile is computed by pooling a window of input
|
||||
// pixel tiles
|
||||
valuesToPool.clear();
|
||||
size_t tilesSkippedByPadding = 0;
|
||||
|
||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
||||
|
||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
||||
tilesSkippedByPadding++;
|
||||
continue;
|
||||
}
|
||||
|
||||
Value inputTile = inputTiles[outTile][inX][inY];
|
||||
|
||||
Value valueToPool;
|
||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
||||
|
||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
||||
|
||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
||||
}
|
||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
||||
if (failed(sendOpOpt)) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"ChannelReceiveOp does not have a matching "
|
||||
"ChannelSendOp.");
|
||||
}
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
valueToPool = sendOp.getData();
|
||||
}
|
||||
else {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Input tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp nor a receiveOp");
|
||||
}
|
||||
|
||||
valuesToPool.push_back(valueToPool);
|
||||
}
|
||||
}
|
||||
|
||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
||||
// assert(computeOpsForPooling.size() != 1 &&
|
||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
||||
// this " "should have been simplified earlier???");
|
||||
|
||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
||||
postProcessFn = [&](const Value prevFinalRes) {
|
||||
return postProcessPoolingWindow(
|
||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
||||
};
|
||||
}
|
||||
|
||||
Value reducedWithinCompute = applyReducePatternNew(
|
||||
valuesToPool,
|
||||
rewriter,
|
||||
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
||||
nullptr,
|
||||
postProcessFn);
|
||||
|
||||
// Send this value through a channel, and receive it in the
|
||||
// `func.func`. During lowering, we will need to "move it" into the
|
||||
// users computeOps
|
||||
auto computeOpOfReduced =
|
||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
||||
|
||||
// Create a new channel before the computeOp
|
||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||
auto reduceChannel =
|
||||
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
|
||||
// Send value through the channel
|
||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
||||
|
||||
// Receive after the computeOp
|
||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||
auto receivedValue =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
|
||||
outputTiles[outTile][outX][outY] = receivedValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
||||
// them!
|
||||
|
||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
auto outputTile = outputTiles[outTile][outX][outY];
|
||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
||||
if (!outputTileProducer) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Output tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp.");
|
||||
}
|
||||
|
||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
||||
|
||||
rewriter.replaceOp(poolOp, outputImage);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
|
||||
ctx);
|
||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVAddOp : Pat<
|
||||
def spatToPimVVAddOp : Pat<
|
||||
(SpatVAddOp:$srcOpRes $a, $b),
|
||||
(PimVAddOp $a, $b,
|
||||
(PimVVAddOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMulOp : Pat<
|
||||
(SpatVMulOp:$srcOpRes $a, $b),
|
||||
(PimVVMulOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMaxOp : Pat<
|
||||
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||
(PimVVMaxOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
||||
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise addition: c = a + b
|
||||
}];
|
||||
@@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise subtraction: c = a - b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise max: c = max(a, b)
|
||||
}];
|
||||
@@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Dot product: c = dot(a, b)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
@@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
|
||||
);
|
||||
}
|
||||
|
||||
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
|
||||
Average all elements into a single one
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $dividend,
|
||||
PimTensor: $divisor,
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
@@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
|
||||
);
|
||||
}
|
||||
|
||||
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise exp: c = exp(a)
|
||||
Element-wise tanh activation
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise sigmoid activation
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@@ -388,4 +480,4 @@ def PimHaltOp: PimOp<"halt", [Terminator]> {
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // PIM_DIALECT_H
|
||||
#endif // PIM_DIALECT_H
|
||||
|
||||
@@ -30,12 +30,13 @@ void PimDialect::initialize() {
|
||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
||||
}
|
||||
|
||||
POPULATE_DEPENDENCIES(PimVMaxOp)
|
||||
POPULATE_DEPENDENCIES(PimVVDMulOp)
|
||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
||||
POPULATE_DEPENDENCIES(PimSumOp)
|
||||
POPULATE_DEPENDENCIES(PimVSDivOp)
|
||||
POPULATE_DEPENDENCIES(PimVAvgOp)
|
||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
||||
POPULATE_DEPENDENCIES(PimVExpOp)
|
||||
POPULATE_DEPENDENCIES(PimVTanhOp)
|
||||
POPULATE_DEPENDENCIES(PimVSigmOp)
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -12,6 +13,26 @@ using namespace bufferization;
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousType,
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
}
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
}
|
||||
};
|
||||
|
||||
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto vaddOp = cast<PimVAddOp>(op);
|
||||
auto binaryOp = cast<OpTy>(op);
|
||||
|
||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
||||
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
|
||||
if (failed(aOpt))
|
||||
return failure();
|
||||
|
||||
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
|
||||
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
|
||||
if (failed(bOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
|
||||
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
|
||||
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
|
||||
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
||||
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
@@ -36,6 +37,25 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
||||
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||
}
|
||||
|
||||
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return pim::PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousBuffer.getType(),
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
|
||||
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
|
||||
@@ -167,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
auto memref = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(memref))
|
||||
return failure();
|
||||
memrefOperands.push_back(*memref);
|
||||
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||
}
|
||||
|
||||
// TODO: Support addiction with more than 2 operands
|
||||
@@ -460,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
|
||||
};
|
||||
|
||||
struct VAddOpInterfaceFromTemplate
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
||||
|
||||
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||
|
||||
@@ -468,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
|
||||
|
||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||
|
||||
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
// Create a new bufferizable op interface for the apply filters operation.
|
||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
||||
@@ -557,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
||||
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
|
||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
@@ -569,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
|
||||
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||
|
||||
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
|
||||
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
||||
|
||||
struct ONNXSigmoidInterface
|
||||
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
||||
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
||||
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
|
||||
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
|
||||
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ add_pim_library(OMPimPasses
|
||||
PimConstantFolding/Patterns/Constant.cpp
|
||||
PimConstantFolding/PimConstantFoldingPass.cpp
|
||||
PimConstantFolding/Patterns/Subview.cpp
|
||||
PimHostVerificationPass.cpp
|
||||
PimMaterializeConstantsPass.cpp
|
||||
PimVerificationPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -120,20 +120,8 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
rewriter.setInsertionPoint(coreOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
|
||||
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth == 0)
|
||||
return failure();
|
||||
size_t totalBytes = initType.getNumElements() * elementByteWidth;
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
initType,
|
||||
mapOp.getInit(),
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
135
src/PIM/Pass/PimMaterializeConstantsPass.cpp
Normal file
135
src/PIM/Pass/PimMaterializeConstantsPass.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
static int64_t getValueSizeInBytes(Value value) {
|
||||
auto type = dyn_cast<ShapedType>(value.getType());
|
||||
if (!type || !type.hasStaticShape())
|
||||
return -1;
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct PimMaterializeConstantsPass
|
||||
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
OpBuilder rewriter(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op.getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op.emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end()) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(&op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
|
||||
|
||||
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(
|
||||
static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(
|
||||
static_cast<int32_t>(totalBytes)));
|
||||
|
||||
cachedByType[originalType] = hostToDevCopy.getResult();
|
||||
operand.set(hostToDevCopy.getResult());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim3_materialized");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
|
||||
return std::make_unique<PimMaterializeConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
|
||||
|
||||
@@ -35,16 +35,24 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-host-pass"; }
|
||||
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that no runtime host-side code remains in bufferized PIM IR";
|
||||
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||
}
|
||||
|
||||
PimHostVerificationPass() {}
|
||||
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
|
||||
PimVerificationPass() {}
|
||||
PimVerificationPass(const PimVerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -132,11 +140,27 @@ private:
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
if (succeeded(resolveContiguousAddress(operand)))
|
||||
continue;
|
||||
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
auto resolvedAddress = resolveContiguousAddress(operand);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
@@ -165,6 +189,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
|
||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createBufferizePimPass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimHostVerificationPass);
|
||||
registerPass(createPimMaterializeConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,18 @@ python3 validation/operations/gen_tests.py
|
||||
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
|
||||
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
||||
|
||||
## Pool
|
||||
|
||||
| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes |
|
||||
|------|-----------|-------|--------|--------|--------|---------|-------|
|
||||
| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling |
|
||||
| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool |
|
||||
| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path |
|
||||
| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling |
|
||||
| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` |
|
||||
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
|
||||
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
|
||||
|
||||
## Gemm
|
||||
|
||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate ONNX test models for validating GEMM and Conv implementations."""
|
||||
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
@@ -248,6 +248,85 @@ def conv_large_spatial():
|
||||
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pooling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def maxpool_basic():
|
||||
"""MaxPool 2x2 with stride 1."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "maxpool_basic", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_basic", "maxpool_basic.onnx")
|
||||
|
||||
|
||||
def maxpool_stride2_multichannel():
|
||||
"""MaxPool 2x2 with stride 2 on multiple channels."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3])
|
||||
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx")
|
||||
|
||||
|
||||
def maxpool_same_upper():
|
||||
"""MaxPool 3x3 with SAME_UPPER padding."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER")
|
||||
graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx")
|
||||
|
||||
|
||||
def avgpool_basic():
|
||||
"""AveragePool 2x2 with stride 1."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3])
|
||||
node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "avgpool_basic", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/avg_basic", "avgpool_basic.onnx")
|
||||
|
||||
|
||||
def avgpool_explicit_padding():
|
||||
"""AveragePool 3x3 with explicit padding, excluding pad from the divisor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||
node = helper.make_node("AveragePool", ["X"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0)
|
||||
graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx")
|
||||
|
||||
|
||||
def avgpool_include_pad():
|
||||
"""AveragePool 3x3 with explicit padding, including pad in the divisor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||
node = helper.make_node("AveragePool", ["X"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1)
|
||||
graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx")
|
||||
|
||||
|
||||
def maxpool_after_conv():
|
||||
"""Conv followed by MaxPool to validate pooling on lowered conv results."""
|
||||
rng = np.random.default_rng(59)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -273,4 +352,13 @@ if __name__ == "__main__":
|
||||
conv_batch_2()
|
||||
conv_large_spatial()
|
||||
|
||||
print("\nGenerating Pooling tests:")
|
||||
maxpool_basic()
|
||||
maxpool_stride2_multichannel()
|
||||
maxpool_same_upper()
|
||||
avgpool_basic()
|
||||
avgpool_explicit_padding()
|
||||
avgpool_include_pad()
|
||||
maxpool_after_conv()
|
||||
|
||||
print("\nDone.")
|
||||
|
||||
BIN
validation/operations/pool/avg_basic/avgpool_basic.onnx
Normal file
BIN
validation/operations/pool/avg_basic/avgpool_basic.onnx
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
validation/operations/pool/max_basic/maxpool_basic.onnx
Normal file
BIN
validation/operations/pool/max_basic/maxpool_basic.onnx
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user