diff --git a/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs b/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs index fb10fe4..fc9d5be 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs @@ -530,6 +530,7 @@ where let r2_val = r2; ensure!(r2_val == 1, "Stride different than 1 not supported"); let rd_val = core.register(rd); + ensure!(offset_select == 1, "Offset select cannot be different from 1"); let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let loads = core.reserve_load(r1_val, imm_len)?.execute_load::()?; let load1 = loads[0]; diff --git a/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs b/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs index 7d4aabe..068e639 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs @@ -224,7 +224,21 @@ fn json_to_vvsub( inst_data_builder: &mut InstructionDataBuilder, json: &Value, ) -> Result<()> { - todo!("Not present in the compiler"); + let json = json.as_object().expect("Not an object"); + assert_eq!("vvsub", json_str!(json, "op")); + let rd = json_i64!(json, "rd") as i32; + let rs1 = json_i64!(json, "rs1") as i32; + let rs2 = json_i64!(json, "rs2") as i32; + let len = json_i64!(json, "len") as i32; + let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap()); + inst_data_builder + .set_rd(rd) + .set_r1(rs1) + .set_r2(rs2) + .set_imm_len(len) + .set_offset_select(offset_select) + .set_offset_value(offset_value); + inst_builder.make_inst(vvsub, inst_data_builder.build()); Ok(()) } @@ -256,7 +270,21 @@ fn json_to_vvdmul( inst_data_builder: &mut InstructionDataBuilder, json: &Value, ) -> Result<()> { - todo!("Not present in the compiler"); + let json = json.as_object().expect("Not an object"); + assert_eq!("vvdmul", json_str!(json, "op")); + let rd = json_i64!(json, "rd") as i32; + let rs1 = json_i64!(json, "rs1") as i32; + let rs2 = json_i64!(json, "rs2") as i32; + let len = json_i64!(json, "len") as i32; + let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap()); + inst_data_builder + .set_rd(rd) + .set_r1(rs1) + .set_r2(rs2) + .set_imm_len(len) + .set_offset_select(offset_select) + .set_offset_value(offset_value); + inst_builder.make_inst(vvdmul, inst_data_builder.build()); Ok(()) } @@ -306,7 +334,21 @@ fn json_to_vavg( inst_data_builder: &mut InstructionDataBuilder, json: &Value, ) -> Result<()> { - todo!("Not present in the compiler"); + let json = json.as_object().expect("Not an object"); + assert_eq!("vavg", json_str!(json, "op")); + let rd = json_i64!(json, "rd") as i32; + let rs1 = json_i64!(json, "rs1") as i32; + let rs2 = json_i64!(json, "rs2") as i32; + let len = json_i64!(json, "len") as i32; + let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap()); + inst_data_builder + .set_rd(rd) + .set_r1(rs1) + .set_r2(rs2) + .set_imm_len(len) + .set_offset_select(offset_select) + .set_offset_value(offset_value); + inst_builder.make_inst(vavg, inst_data_builder.build()); Ok(()) } @@ -358,7 +400,7 @@ fn json_to_vsigm( json: &Value, ) -> Result<()> { let json = json.as_object().expect("Not an object"); - assert_eq!("vsigmoid", json_str!(json, "op")); + assert_eq!("vsigm", json_str!(json, "op")); let rd = json_i64!(json, "rd") as i32; let rs1 = json_i64!(json, "rs1") as i32; let len = json_i64!(json, "len") as i32; diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 50c539f..20b71c3 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -237,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) } -void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const { - auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vaddOp.getA()); - auto bAddr = memory.getValueAddress(vaddOp.getB()); - setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); +static size_t getValueSizeInBytes(mlir::Value value) { + auto type = cast(value.getType()); + return type.getNumElements() * type.getElementTypeBitWidth() / 8; +} - auto outputType = cast(vaddOp.getOutBuf().getType()); - size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8; +void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const { + auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vvaddOp.getA()); + auto bAddr = memory.getValueAddress(vvaddOp.getB()); + setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); json::Object json; json["op"] = "vvadd"; @@ -252,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); - json["len"] = totalBytes; + json["len"] = getValueSizeInBytes(vvaddOp.getA()); emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const { - auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vmaxOp.getA()); - auto bAddr = memory.getValueAddress(vmaxOp.getB()); +void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const { + auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vvsubOp.getA()); + auto bAddr = memory.getValueAddress(vvsubOp.getB()); + setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + + json::Object json; + json["op"] = "vvsub"; + json["rd"] = 0; + json["rs1"] = 1; + json["rs2"] = 2; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vvsubOp.getA()); + emitInstruction(std::move(json)); +} + +void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const { + auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vvmulOp.getA()); + auto bAddr = memory.getValueAddress(vvmulOp.getB()); + setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + + json::Object json; + json["op"] = "vvmul"; + json["rd"] = 0; + json["rs1"] = 1; + json["rs2"] = 2; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vvmulOp.getA()); + emitInstruction(std::move(json)); +} + +void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const { + auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vvmaxOp.getA()); + auto bAddr = memory.getValueAddress(vvmaxOp.getB()); setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); json::Object json; @@ -268,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vvmaxOp.getA()); + emitInstruction(std::move(json)); +} + +void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const { + auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vvdmulOp.getA()); + auto bAddr = memory.getValueAddress(vvdmulOp.getB()); + setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + + json::Object json; + json["op"] = "vvdmul"; + json["rd"] = 0; + json["rs1"] = 1; + json["rs2"] = 2; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vvdmulOp.getA()); + emitInstruction(std::move(json)); +} + +void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const { + auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vavgOp.getA()); + setupRdRs1(outBufAddr, 0, aAddr, 0); + + json::Object json; + json["op"] = "vavg"; + json["rd"] = 0; + json["rs1"] = 1; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vavgOp.getA()); emitInstruction(std::move(json)); } @@ -281,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const { json["rd"] = 0; json["rs1"] = 1; json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vreluOp.getA()); + emitInstruction(std::move(json)); +} + +void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const { + auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vtanhOp.getA()); + setupRdRs1(outBufAddr, 0, aAddr, 0); + + json::Object json; + json["op"] = "vtanh"; + json["rd"] = 0; + json["rs1"] = 1; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vtanhOp.getA()); + emitInstruction(std::move(json)); +} + +void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const { + auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf()); + auto aAddr = memory.getValueAddress(vsigmOp.getA()); + setupRdRs1(outBufAddr, 0, aAddr, 0); + + json::Object json; + json["op"] = "vsigm"; + json["rd"] = 0; + json["rs1"] = 1; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vsigmOp.getA()); emitInstruction(std::move(json)); } @@ -338,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co vaddJson["rs1"] = 1; vaddJson["rs2"] = 2; vaddJson["offset"] = createEmptyOffset(); + vaddJson["len"] = 32 * outChannels; emitInstruction(std::move(vaddJson)); } } @@ -479,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); else if (auto transposeOp = dyn_cast(op)) coreCodeGen.codeGenTransposeOp(transposeOp); - else if (auto vaddOp = dyn_cast(op)) - coreCodeGen.codeGenVAddOp(vaddOp); - else if (auto vmaxOp = dyn_cast(op)) - coreCodeGen.codeGenVMaxOp(vmaxOp); + else if (auto vvaddOp = dyn_cast(op)) + coreCodeGen.codeGenVVAddOp(vvaddOp); + else if (auto vvsubOp = dyn_cast(op)) + coreCodeGen.codeGenVVSubOp(vvsubOp); + else if (auto vvmulOp = dyn_cast(op)) + coreCodeGen.codeGenVVMulOp(vvmulOp); + else if (auto vvmaxOp = dyn_cast(op)) + coreCodeGen.codeGenVVMaxOp(vvmaxOp); + else if (auto vvdmulOp = dyn_cast(op)) + coreCodeGen.codeGenVVDMulOp(vvdmulOp); + else if (auto vavgOp = dyn_cast(op)) + coreCodeGen.codeGenVAvgOp(vavgOp); else if (auto vreluOp = dyn_cast(op)) coreCodeGen.codeGenVReluOp(vreluOp); - else if (isa(op)) { + else if (auto vtanhOp = dyn_cast(op)) + coreCodeGen.codeGenVTanhOp(vtanhOp); + else if (auto vsigmOp = dyn_cast(op)) + coreCodeGen.codeGenVSigmOp(vsigmOp); + else if (isa(op)) { // TODO: Implement somehow? op.emitWarning("Operation is not yet supported in code generation"); continue; diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 38c4bc0..e08dfda 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -90,9 +90,15 @@ public: template void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix); - void codeGenVAddOp(pim::PimVAddOp vaddOp) const; - void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const; + void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const; + void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const; + void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const; + void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const; + void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const; + void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const; void codeGenVReluOp(pim::PimVReluOp vreluOp) const; + void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const; + void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; }; diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index b346454..ba8af85 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -47,8 +47,9 @@ void addPassesPim(OwningOpRef& module, if (pimEmissionTarget >= EmitPimCodegen) { pm.addPass(createPimConstantFoldingPass()); pm.addPass(createMessagePass("Pim constants folded")); - pm.addPass(createPimHostVerificationPass()); - pm.addPass(createMessagePass("Pim host verified")); + pm.addPass(createPimMaterializeConstantsPass()); + pm.addPass(createPimVerificationPass()); + pm.addPass(createMessagePass("Pim verified")); pm.addPass(createEmitPimJsonPass()); // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Pim json code emitted")); diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 2ecfe24..383f7f3 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -6,7 +6,7 @@ add_pim_library(OMONNXToSpatial Patterns/Math/Gemm.cpp Patterns/Math/Conv.cpp Patterns/Math/MatMul.cpp - Patterns/NN/Pooling.cpp + Patterns/NN/Pool.cpp Patterns/NN/ReduceMean.cpp Patterns/Tensor/Concat.cpp Patterns/Tensor/Reshape.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index bb929e1..2ba333e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -93,7 +93,7 @@ void ONNXToSpatialPass::runOnOperation() { patterns.add(ctx); populateConvOpPatterns(patterns, ctx); - populatePoolingTilingPattern(patterns, ctx); + populatePoolTilingPattern(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx); populateReshapeConversionPattern(patterns, ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index 4311851..2b29f4a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -11,7 +11,7 @@ void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIR void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp new file mode 100644 index 0000000..f558c47 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -0,0 +1,265 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "llvm/ADT/SmallVector.h" + +#include +#include +#include +#include + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +template +static int64_t getI64(ArrayAttrT arrayAttr, size_t index) { + return cast(arrayAttr[index]).getInt(); +} + +template +static int64_t getOptionalI64(std::optional arrayAttr, size_t index, int64_t defaultValue) { + return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; +} + +static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef values) { + assert(!values.empty() && "Expected at least one value to concatenate."); + if (values.size() == 1) + return values.front(); + return tensor::ConcatOp::create(rewriter, loc, axis, values); +} + +static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { + auto tileType = cast(tile.getType()); + Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); + + SmallVector offsets(tileType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + sizes.reserve(tileType.getRank()); + for (int64_t dimSize : tileType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dimSize)); + SmallVector strides(tileType.getRank(), rewriter.getIndexAttr(1)); + + return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides); +} + +template +static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef windowValues) { + assert(!windowValues.empty() && "Expected at least one pool window value."); + + Value reduced = windowValues.front(); + for (Value value : windowValues.drop_front()) + reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value); + return reduced; +} + +static Value +scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) { + assert(divisor > 0 && "AveragePool divisor must be positive."); + if (divisor == 1) + return reducedWindow; + + auto tileType = cast(reducedWindow.getType()); + double scale = 1.0 / static_cast(divisor); + auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale)); + Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr); + return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor); +} + +template +struct PoolToSpatialCompute; + +template +struct PoolToSpatialComputeBase : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { + Location loc = poolOp.getLoc(); + Value x = adaptor.getX(); + + auto xType = dyn_cast(x.getType()); + auto outType = dyn_cast(poolOp.getResult().getType()); + if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape()) + return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types."); + if (xType.getRank() != 4 || outType.getRank() != 4) + return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported."); + + ArrayAttr kernelAttr = poolOp.getKernelShape(); + if (!kernelAttr || kernelAttr.size() != 2) + return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel."); + + const int64_t batchSize = xType.getDimSize(0); + const int64_t channels = xType.getDimSize(1); + const int64_t inputHeight = xType.getDimSize(2); + const int64_t inputWidth = xType.getDimSize(3); + const int64_t outputHeight = outType.getDimSize(2); + const int64_t outputWidth = outType.getDimSize(3); + const int64_t kernelHeight = getI64(kernelAttr, 0); + const int64_t kernelWidth = getI64(kernelAttr, 1); + const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1); + const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1); + const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1); + const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1); + + int64_t padTop = 0; + int64_t padLeft = 0; + int64_t padBottom = 0; + int64_t padRight = 0; + + if (auto padsAttr = poolOp.getPads()) { + if (padsAttr->size() != 4) + return rewriter.notifyMatchFailure(poolOp, "pads must have four elements."); + padTop = getI64(*padsAttr, 0); + padLeft = getI64(*padsAttr, 1); + padBottom = getI64(*padsAttr, 2); + padRight = getI64(*padsAttr, 3); + } + else { + StringRef autoPad = poolOp.getAutoPad(); + if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1; + const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1; + const int64_t totalPadH = + std::max(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight); + const int64_t totalPadW = std::max(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth); + + if (autoPad == "SAME_UPPER") { + padTop = totalPadH / 2; + padBottom = totalPadH - padTop; + padLeft = totalPadW / 2; + padRight = totalPadW - padLeft; + } + else { + padBottom = totalPadH / 2; + padTop = totalPadH - padBottom; + padRight = totalPadW / 2; + padLeft = totalPadW - padRight; + } + } + else if (autoPad != "NOTSET" && autoPad != "VALID") { + return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value."); + } + } + + (void) padBottom; + (void) padRight; + + const int64_t xbarSize = static_cast(crossbarSize.getValue()); + const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize; + auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector(), ValueRange {x}); + + auto* computeBlock = new Block(); + computeBlock->addArgument(xType, loc); + computeOp.getBody().push_back(computeBlock); + rewriter.setInsertionPointToStart(computeBlock); + + Value input = computeBlock->getArgument(0); + SmallVector batchResults; + batchResults.reserve(batchSize); + + for (int64_t batch = 0; batch < batchSize; ++batch) { + SmallVector rows; + rows.reserve(outputHeight); + + for (int64_t outH = 0; outH < outputHeight; ++outH) { + SmallVector rowPixels; + rowPixels.reserve(outputWidth); + + for (int64_t outW = 0; outW < outputWidth; ++outW) { + SmallVector outputChannelTiles; + outputChannelTiles.reserve(channelTileCount); + + for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { + const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); + auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); + + SmallVector windowValues; + windowValues.reserve(kernelHeight * kernelWidth); + for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { + const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop; + if (inH < 0 || inH >= inputHeight) + continue; + + for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { + const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft; + if (inW < 0 || inW >= inputWidth) + continue; + + SmallVector offsets = {rewriter.getIndexAttr(batch), + rewriter.getIndexAttr(channelTile * xbarSize), + rewriter.getIndexAttr(inH), + rewriter.getIndexAttr(inW)}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value windowValue = + tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides); + windowValue = materializeContiguousTile(rewriter, loc, windowValue); + windowValues.push_back(windowValue); + } + } + + if (windowValues.empty()) + return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements."); + + Value reducedWindow = reduceWindowValues(rewriter, loc, windowValues); + if constexpr (std::is_same_v) { + const bool countIncludePad = poolOp.getCountIncludePad() == 1; + const int64_t divisor = + countIncludePad ? kernelHeight * kernelWidth : static_cast(windowValues.size()); + reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor); + } + + outputChannelTiles.push_back(reducedWindow); + } + + rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles)); + } + + rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels)); + } + + batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows)); + } + + Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults); + spatial::SpatYieldOp::create(rewriter, loc, pooledOutput); + + rewriter.replaceOp(poolOp, computeOp.getResult(0)); + return success(); + } +}; + +template <> +struct PoolToSpatialCompute +: public PoolToSpatialComputeBase { + using PoolToSpatialComputeBase::PoolToSpatialComputeBase; +}; + +template <> +struct PoolToSpatialCompute +: public PoolToSpatialComputeBase { + using PoolToSpatialComputeBase::PoolToSpatialComputeBase; +}; + +} // namespace + +void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert>(ctx); + patterns.insert>(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pooling.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pooling.cpp deleted file mode 100644 index 6b4c104..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pooling.cpp +++ /dev/null @@ -1,427 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include - -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -Value applyReducePatternNew(SmallVector& valuesToReduce, - ConversionPatternRewriter& rewriter, - std::function reduce, - std::function preprocess, - std::function postprocess) { - // Simple case: if we have only one input, just return it - if (valuesToReduce.size() == 1) - return valuesToReduce[0]; - - if (preprocess) { - for (auto& valToReduce : valuesToReduce) { - rewriter.setInsertionPointAfterValue(valToReduce); - valToReduce = preprocess(valToReduce); - } - } - - // It is possible that `valuesToReduce` contains two entries for the same - // computeOp. In this case, we need to apply the reduction within-computef - - // Keep a map between a computeOp and the last Value for this reduction - std::unordered_map lastValueForCompute; - for (auto& valToReduce : valuesToReduce) { - Operation* computeOp = valToReduce.getParentBlock()->getParentOp(); - // if (valToReduce.getDefiningOp()) { - // // If the value is defined by an operation, we take the parent - // operation computeOp = valToReduce.getDefiningOp()->getParentOp(); - // } else { - // // Otherwise it is a block argument, - // computeOp->getBlock()->getParentOp(); - // } - - assert(isa(computeOp) && "Expected a ComputeOp"); - - auto it = lastValueForCompute.find(computeOp); - - if (it != lastValueForCompute.end()) { - // If we have already seen this computeOp, apply the reduction - // within-compute - Value lastWithinComputeValue = it->second; - - if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp())) - rewriter.setInsertionPointAfterValue(lastWithinComputeValue); - else - rewriter.setInsertionPointAfterValue(valToReduce); - valToReduce = reduce(lastWithinComputeValue, valToReduce); - lastValueForCompute[computeOp] = valToReduce; - } - - lastValueForCompute[computeOp] = valToReduce; - } - - // Now, reconstruct from the map the valuesToReduce list - valuesToReduce.clear(); - valuesToReduce.reserve(lastValueForCompute.size()); - for (auto& entry : lastValueForCompute) - valuesToReduce.push_back(entry.second); - - Location loc = valuesToReduce[0].getLoc(); - auto channelType = spatial::SpatChannelType::get(rewriter.getContext()); - - // Recursive algorithm to reduce the inputs to a single one: - // - Take two inputs at a time, and reduce them into a single one, updating - // the valuesToReduce list which becomes half the size. - // - Repeat until there is only one input left. - llvm::OwningArrayRef valuesToReduceRef(valuesToReduce); - while (valuesToReduceRef.size() > 1) { - SmallVector nextValuesToReduce; - nextValuesToReduce.reserve(valuesToReduceRef.size() / 2); - for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) { - auto firstValue = valuesToReduceRef[i]; - auto secondValue = valuesToReduceRef[i + 1]; - - auto firstCompute = firstValue.getParentBlock()->getParentOp(); - auto secondCompute = secondValue.getParentBlock()->getParentOp(); - - assert(isa(firstCompute)); - assert(isa(secondCompute)); - - if (secondCompute->isBeforeInBlock(firstCompute)) { - std::swap(firstValue, secondValue); - std::swap(firstCompute, secondCompute); - } - - // 1. Add a channel before the first computeOp - rewriter.setInsertionPoint(firstCompute); - auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType); - - // 2. Add a sendOp after the first value - rewriter.setInsertionPointAfterValue(firstValue); - spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue); - - // 3. Add a receiveOp after the second value - rewriter.setInsertionPointAfterValue(secondValue); - auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel); - - // 4. Apply reduction between second value and received value - rewriter.setInsertionPointAfterValue(receivedValue); - Value reduced = reduce(receivedValue, secondValue); - - nextValuesToReduce.push_back(reduced); - } - - // If we have an odd number of inputs, we need to add the last one to the - // newInputs list. - if (valuesToReduceRef.size() % 2 == 1) - nextValuesToReduce.push_back(valuesToReduceRef.back()); - - // Replace the inputOps list with the new one. - valuesToReduceRef = llvm::OwningArrayRef(std::move(nextValuesToReduce)); - } - - assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point."); - - auto finalValue = valuesToReduceRef[0]; - - if (postprocess) { - rewriter.setInsertionPointAfterValue(finalValue); - finalValue = postprocess(finalValue); - } - - return finalValue; -} - -template -bool hasPostProcessPoolingWindow() { - return false; -} - -template <> -bool hasPostProcessPoolingWindow() { - return true; -} - -template -Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter, - Location loc, - PoolOp poolOp, - Value valueToDivide, - size_t krn_size, - size_t tilesSkippedByPadding) { - return nullptr; -} - -template <> -Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter, - Location loc, - ONNXAveragePoolOp poolOp, - Value valueToDivide, - size_t krn_size, - size_t tilesSkippedByPadding) { - bool countIncludePad = poolOp.getCountIncludePad() == 1; - - size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; - - RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type()); - - // Put a spat.const before the computeOp, and use its value. We do this to be - // compatible with the current code generation, which assumes constant to be - // loaded in global memory, which is allocated by adding a spat.const OP - // directly under func.func (i.e. alongside ComputeOps) - auto computeOp = cast(valueToDivide.getDefiningOp()->getParentOp()); - rewriter.setInsertionPoint(computeOp); - auto divisorValue = spatial::SpatConstantOp::create(rewriter, - loc, - scalarTensor, - rewriter.getI64IntegerAttr(divisorNumber), - /* should_allocate = */ rewriter.getBoolAttr(true)); - - rewriter.setInsertionPointAfterValue(valueToDivide); - return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue); -} - -template -struct PoolingBaseConverter : public OpConversionPattern { - PoolingBaseConverter(MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - Value X = adaptor.getX(); - ShapedType xShape = mlir::cast(X.getType()); - Value Y = poolOp.getResult(); - ShapedType yShape = mlir::cast(Y.getType()); - - size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h; - unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y); - unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); - unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); - - if (adaptor.getAutoPad() != "NOTSET") - return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated."); - - size_t pad_x, pad_y; - auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) - return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); - - Location loc = poolOp.getLoc(); - - size_t input_h = getImageHeight(xShape); - size_t input_w = getImageWidth(xShape); - size_t output_h = getImageHeight(yShape); - size_t output_w = getImageWidth(yShape); - size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue()); - size_t channelTileRest = getImageChannel(xShape) % crossbarSize; - - // 1: Tile the input tensor - // Input tiles need to be indexed by: - // a. Channel Tile - // b. Pixel `x` position - // c. Pixel `y` position - // For example: inputTiles[channelTile][x][y] - // Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH) - // Suppose that the input tensor is produced by concatenating the results of - // many ComputeOps. Get the result tiles from these ComputeOps. - SmallVector>> inputTiles( - channelTileCount, SmallVector>(input_w, SmallVector(input_h))); - - auto resolveErrorOpt = - resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter); - if (resolveErrorOpt.has_value()) - return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt); - - // TODO: This requires a core for each input tile, which is not ideal. We - // can do better. - // If some input tiles come from the func.func operands, load - // them into a computeOp and yield them - for (size_t t = 0; t < channelTileCount; t++) { - for (size_t x = 0; x < input_w; x++) { - for (size_t y = 0; y < input_h; y++) { - if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp()) { - Location tileLoc = extractSliceOp.getLoc(); - - auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter, - tileLoc, - extractSliceOp.getResultType(), - /* xbarWeights =*/ValueRange(), - extractSliceOp.getResult()); - - Block* tempComputeOpBlock = new Block(); - tempComputeOp.getBody().push_back(tempComputeOpBlock); - auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc); - - rewriter.setInsertionPointToStart(tempComputeOpBlock); - spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg); - rewriter.setInsertionPointAfter(tempComputeOp); - inputTiles[t][x][y] = tempComputeOp.getResult(0); - } - } - } - } - - // 2: Tile the output tensor - // Output tiles need to be indexed by: - // a. Channel Tile - // b. Pixel `x` position - // c. Pixel `y` position - // For example: outputTiles[channelTile][x][y] - // Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH) - SmallVector>> outputTiles( - channelTileCount, SmallVector>(output_w, SmallVector(output_h, nullptr))); - - // List of values to pool for each output pixel - SmallVector valuesToPool; - - // Iterate each output tile - for (size_t outTile = 0; outTile < channelTileCount; outTile++) { - // Iterate each output pixel - for (size_t outX = 0; outX < output_w; outX++) { - for (size_t outY = 0; outY < output_h; outY++) { - - // Each output pixel tile is computed by pooling a window of input - // pixel tiles - valuesToPool.clear(); - size_t tilesSkippedByPadding = 0; - - auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x); - auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y); - - for (size_t inX = start_x; inX < end_x; inX += dilation_x) { - for (size_t inY = start_y; inY < end_y; inY += dilation_y) { - if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) { - tilesSkippedByPadding++; - continue; - } - - Value inputTile = inputTiles[outTile][inX][inY]; - - Value valueToPool; - if (auto computeProducer = inputTile.getDefiningOp()) { - - int resultNumber = getResultIndex(computeProducer, inputTile); - - auto yieldInComputeOp = cast(computeProducer.getBody().front().getTerminator()); - valueToPool = yieldInComputeOp.getOperand(resultNumber); - } - else if (auto receiveProducer = inputTile.getDefiningOp()) { - auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter); - if (failed(sendOpOpt)) { - return rewriter.notifyMatchFailure(poolOp, - "ChannelReceiveOp does not have a matching " - "ChannelSendOp."); - } - auto sendOp = cast(*sendOpOpt); - - valueToPool = sendOp.getData(); - } - else { - return rewriter.notifyMatchFailure(poolOp, - "Input tile for Pooling is not produced by a " - "WeightedComputeOp nor a receiveOp"); - } - - valuesToPool.push_back(valueToPool); - } - } - - assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense."); - // assert(computeOpsForPooling.size() != 1 && - // "Pooling computed on one tiles make no sense??? Or maybe - // this " "should have been simplified earlier???"); - - std::function postProcessFn = nullptr; - if (hasPostProcessPoolingWindow()) { - postProcessFn = [&](const Value prevFinalRes) { - return postProcessPoolingWindow( - rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding); - }; - } - - Value reducedWithinCompute = applyReducePatternNew( - valuesToPool, - rewriter, - [&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); }, - nullptr, - postProcessFn); - - // Send this value through a channel, and receive it in the - // `func.func`. During lowering, we will need to "move it" into the - // users computeOps - auto computeOpOfReduced = - cast(reducedWithinCompute.getDefiningOp()->getParentOp()); - - // Create a new channel before the computeOp - rewriter.setInsertionPoint(computeOpOfReduced); - auto reduceChannel = - spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext())); - - // Send value through the channel - rewriter.setInsertionPointAfterValue(reducedWithinCompute); - spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute); - - // Receive after the computeOp - rewriter.setInsertionPointAfter(computeOpOfReduced); - auto receivedValue = - spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel); - - outputTiles[outTile][outX][outY] = receivedValue; - } - } - } - - // TODO: outputTiles are not the results of the computeOps! We need to add - // them! - - std::unordered_map>> computeOpNeedingResults; - - // Iterate each output tile - for (size_t outTile = 0; outTile < channelTileCount; outTile++) { - // Iterate each output pixel - for (size_t outX = 0; outX < output_w; outX++) { - for (size_t outY = 0; outY < output_h; outY++) { - auto outputTile = outputTiles[outTile][outX][outY]; - auto outputTileProducer = outputTile.getDefiningOp()->getParentOp(); - if (!outputTileProducer) { - return rewriter.notifyMatchFailure(poolOp, - "Output tile for Pooling is not produced by a " - "WeightedComputeOp."); - } - - computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile)); - } - } - } - - Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType()); - - rewriter.replaceOp(poolOp, outputImage); - - return success(); - } -}; - -void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert>( - ctx); - patterns.insert>(ctx); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 4dc16eb..c97932a 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; -def spatToPimVAddOp : Pat< +def spatToPimVVAddOp : Pat< (SpatVAddOp:$srcOpRes $a, $b), - (PimVAddOp $a, $b, + (PimVVAddOp $a, $b, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + +def spatToPimVVMulOp : Pat< + (SpatVMulOp:$srcOpRes $a, $b), + (PimVVMulOp $a, $b, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + +def spatToPimVVMaxOp : Pat< + (SpatVMaxOp:$srcOpRes $a, $b), + (PimVVMaxOp $a, $b, (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 273288d..99f2a58 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> { }]; } -def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> { +def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> { let description = [{ Element-wise addition: c = a + b }]; @@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> { }]; } -def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods]> { +def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> { + let description = [{ + Element-wise subtraction: c = a - b + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $b, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutBufMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) + }]; +} + +def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> { + let description = [{ + Element-wise multiplication: c = a * b + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $b, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutBufMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) + }]; +} + +def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> { let description = [{ Element-wise max: c = max(a, b) }]; @@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods` type($outRes) + }]; +} + +def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods]> { + let description = [{ + Dot product: c = dot(a, b) + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $b, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); } def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods]> { @@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods ); } -def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods]> { +def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods]> { let description = [{ - Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience) + Average all elements into a single one }]; let arguments = (ins - PimTensor: $dividend, - PimTensor: $divisor, + PimTensor: $a, PimTensor: $outBuf ); @@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods]> { +def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods]> { let description = [{ - Element-wise exp: c = exp(a) + Element-wise tanh activation + }]; + + let arguments = (ins + PimTensor: $a, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); +} + +def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods]> { + let description = [{ + Element-wise sigmoid activation }]; let arguments = (ins @@ -388,4 +480,4 @@ def PimHaltOp: PimOp<"halt", [Terminator]> { }]; } -#endif // PIM_DIALECT_H \ No newline at end of file +#endif // PIM_DIALECT_H diff --git a/src/PIM/Dialect/Pim/PimOps.cpp b/src/PIM/Dialect/Pim/PimOps.cpp index 3ea7070..8e9d2f6 100644 --- a/src/PIM/Dialect/Pim/PimOps.cpp +++ b/src/PIM/Dialect/Pim/PimOps.cpp @@ -30,12 +30,13 @@ void PimDialect::initialize() { registerDependenciesFn(this->getOutBuf(), this->getResult()); \ } -POPULATE_DEPENDENCIES(PimVMaxOp) +POPULATE_DEPENDENCIES(PimVVDMulOp) POPULATE_DEPENDENCIES(PimApplyFiltersOp) POPULATE_DEPENDENCIES(PimSumOp) -POPULATE_DEPENDENCIES(PimVSDivOp) +POPULATE_DEPENDENCIES(PimVAvgOp) POPULATE_DEPENDENCIES(PimVReluOp) -POPULATE_DEPENDENCIES(PimVExpOp) +POPULATE_DEPENDENCIES(PimVTanhOp) +POPULATE_DEPENDENCIES(PimVSigmOp) } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index acffded..55d886f 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -4,6 +4,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "OpBufferizationInterfaces.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; @@ -12,6 +13,26 @@ using namespace bufferization; namespace onnx_mlir { namespace pim { +static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { + if (succeeded(resolveContiguousAddress(memrefValue))) + return memrefValue; + + auto shapedType = cast(memrefValue.getType()); + auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); + Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); + auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; + + return PimMemCopyOp::create(rewriter, + loc, + contiguousType, + contiguousBuffer, + memrefValue, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(sizeInBytes)) + .getDstOut(); +} + struct MemCopyHostToDevOpInterface : DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation* op, @@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel { +template +struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); } @@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); + auto binaryOp = cast(op); - auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state); + auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state); if (failed(aOpt)) return failure(); - auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state); + auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state); if (failed(bOpt)) return failure(); - auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state); + auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state); if (failed(outBufOpt)) return failure(); - replaceOpWithNewBufferizedOp(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt); + Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter); + Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter); + + replaceOpWithNewBufferizedOp(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt); return success(); } }; @@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimTransposeOp::attachInterface(*ctx); PimVMMOp::attachInterface(*ctx); PimMVMOp::attachInterface(*ctx); - PimVAddOp::attachInterface(*ctx); + PimVVAddOp::attachInterface>(*ctx); + PimVVSubOp::attachInterface>(*ctx); + PimVVMulOp::attachInterface>(*ctx); + PimVVMaxOp::attachInterface>(*ctx); }); } diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index 51cdc10..304ef64 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -17,6 +17,7 @@ #include +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" @@ -36,6 +37,25 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& return memref::AllocOp::create(rewriter, loc, memrefResultType); } +Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { + if (succeeded(resolveContiguousAddress(memrefValue))) + return memrefValue; + + auto shapedType = cast(memrefValue.getType()); + auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter); + auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; + + return pim::PimMemCopyOp::create(rewriter, + loc, + contiguousBuffer.getType(), + contiguousBuffer, + memrefValue, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(sizeInBytes)) + .getDstOut(); +} + const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); static FailureOr getOtherEndOfChannel(Operation* op, bool opIsReceive) { @@ -167,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa auto memref = getBuffer(rewriter, operand, options, state); if (failed(memref)) return failure(); - memrefOperands.push_back(*memref); + memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter)); } // TODO: Support addiction with more than 2 operands @@ -460,7 +480,7 @@ struct ChannelBroadcastSendOpInterface }; struct VAddOpInterfaceFromTemplate -: VariadicArgumentElementWiseOpInterface {}; +: VariadicArgumentElementWiseOpInterface {}; struct WVMMOpInterface : WeightedMultiplicationsOpInterface {}; @@ -468,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface {}; -struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface {}; - -struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface {}; +struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface {}; // Create a new bufferizable op interface for the apply filters operation. struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel { @@ -557,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { SpatWeightedVMMOp::attachInterface(*ctx); SpatWeightedMVMOp::attachInterface(*ctx); SpatSumOp::attachInterface(*ctx); - SpatVSDivOp::attachInterface(*ctx); SpatVMaxOp::attachInterface(*ctx); SpatChannelReceiveOp::attachInterface(*ctx); SpatChannelSendOp::attachInterface(*ctx); @@ -569,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface {}; -struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface {}; +struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface {}; + +struct ONNXSigmoidInterface +: VariadicArgumentElementWiseOpInterface {}; void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) { ONNXReluOp::attachInterface(*ctx); - ONNXExpOp::attachInterface(*ctx); + ONNXTanhOp::attachInterface(*ctx); + ONNXSigmoidOp::attachInterface(*ctx); }); } diff --git a/src/PIM/Pass/CMakeLists.txt b/src/PIM/Pass/CMakeLists.txt index 8c4dfd6..bf509d1 100644 --- a/src/PIM/Pass/CMakeLists.txt +++ b/src/PIM/Pass/CMakeLists.txt @@ -5,7 +5,8 @@ add_pim_library(OMPimPasses PimConstantFolding/Patterns/Constant.cpp PimConstantFolding/PimConstantFoldingPass.cpp PimConstantFolding/Patterns/Subview.cpp - PimHostVerificationPass.cpp + PimMaterializeConstantsPass.cpp + PimVerificationPass.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Pass/PimConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimConstantFolding/Patterns/Constant.cpp index 3580589..d11b4c0 100644 --- a/src/PIM/Pass/PimConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/PimConstantFolding/Patterns/Constant.cpp @@ -120,20 +120,8 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { rewriter.setInsertionPoint(coreOp); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); - size_t elementByteWidth = initType.getElementTypeBitWidth() / 8; - if (elementByteWidth == 0) - return failure(); - size_t totalBytes = initType.getNumElements() * elementByteWidth; - rewriter.setInsertionPoint(mapOp); - pim::PimMemCopyHostToDevOp::create(rewriter, - mapOp.getLoc(), - initType, - mapOp.getInit(), - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(static_cast(totalBytes))); + rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp); rewriter.eraseOp(mapOp); return success(); } diff --git a/src/PIM/Pass/PimMaterializeConstantsPass.cpp b/src/PIM/Pass/PimMaterializeConstantsPass.cpp new file mode 100644 index 0000000..74d809b --- /dev/null +++ b/src/PIM/Pass/PimMaterializeConstantsPass.cpp @@ -0,0 +1,135 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { + if (isa(op)) + return operandIndex == 1; + if (isa(op)) + return operandIndex == 0; + return false; +} + +static int64_t getValueSizeInBytes(Value value) { + auto type = dyn_cast(value.getType()); + if (!type || !type.hasStaticShape()) + return -1; + return type.getNumElements() * type.getElementTypeBitWidth() / 8; +} + +struct PimMaterializeConstantsPass + : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass) + + StringRef getArgument() const override { return "materialize-pim-constants"; } + StringRef getDescription() const override { + return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops"; + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + OpBuilder rewriter(moduleOp.getContext()); + bool hasFailure = false; + + for (func::FuncOp funcOp : moduleOp.getOps()) { + if (funcOp.isExternal()) + continue; + + for (pim::PimCoreOp coreOp : funcOp.getOps()) { + DenseMap>> materializedValues; + + for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) { + if (isa(op)) + continue; + + for (OpOperand& operand : op.getOpOperands()) { + Value originalValue = operand.get(); + if (!isa(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber())) + continue; + + auto resolvedAddress = resolveContiguousAddress(originalValue); + if (failed(resolvedAddress)) + continue; + + auto getGlobalOp = dyn_cast_or_null(resolvedAddress->base.getDefiningOp()); + if (!getGlobalOp) + continue; + + auto originalType = dyn_cast(originalValue.getType()); + if (!originalType || !originalType.hasStaticShape()) { + op.emitOpError("host constant materialization requires a static memref operand"); + hasFailure = true; + continue; + } + + auto& cachedByOffset = materializedValues[resolvedAddress->base]; + auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset]; + auto cachedValue = cachedByType.find(originalType); + if (cachedValue != cachedByType.end()) { + operand.set(cachedValue->second); + continue; + } + + int64_t totalBytes = getValueSizeInBytes(originalValue); + if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { + op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); + hasFailure = true; + continue; + } + + auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType()); + + rewriter.setInsertionPoint(&op); + Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType); + Value deviceDst = localAlloc; + if (contiguousType != originalType) + deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc); + + auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter, + op.getLoc(), + originalType, + deviceDst, + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr( + static_cast(resolvedAddress->byteOffset)), + rewriter.getI32IntegerAttr( + static_cast(totalBytes))); + + cachedByType[originalType] = hostToDevCopy.getResult(); + operand.set(hostToDevCopy.getResult()); + } + } + } + } + + if (hasFailure) { + signalPassFailure(); + return; + } + + dumpModule(moduleOp, "pim3_materialized"); + } +}; + +} // namespace + +std::unique_ptr createPimMaterializeConstantsPass() { + return std::make_unique(); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimPasses.hpp b/src/PIM/Pass/PimPasses.hpp index 24681b0..7b59c8f 100644 --- a/src/PIM/Pass/PimPasses.hpp +++ b/src/PIM/Pass/PimPasses.hpp @@ -17,7 +17,9 @@ std::unique_ptr createBufferizePimPass(); std::unique_ptr createPimConstantFoldingPass(); -std::unique_ptr createPimHostVerificationPass(); +std::unique_ptr createPimMaterializeConstantsPass(); + +std::unique_ptr createPimVerificationPass(); std::unique_ptr createEmitPimJsonPass(); diff --git a/src/PIM/Pass/PimHostVerificationPass.cpp b/src/PIM/Pass/PimVerificationPass.cpp similarity index 75% rename from src/PIM/Pass/PimHostVerificationPass.cpp rename to src/PIM/Pass/PimVerificationPass.cpp index 827fa56..d0f2ed7 100644 --- a/src/PIM/Pass/PimHostVerificationPass.cpp +++ b/src/PIM/Pass/PimVerificationPass.cpp @@ -35,16 +35,24 @@ static bool isCodegenAddressableValue(Value value) { || isa(resolvedAddress->base.getDefiningOp()); } -struct PimHostVerificationPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass) +static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { + if (isa(op)) + return operandIndex == 1; + if (isa(op)) + return operandIndex == 0; + return false; +} - StringRef getArgument() const override { return "verify-pim-host-pass"; } +struct PimVerificationPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass) + + StringRef getArgument() const override { return "verify-pim-pass"; } StringRef getDescription() const override { - return "Verify that no runtime host-side code remains in bufferized PIM IR"; + return "Verify that bufferized PIM IR contains only explicit host/device transfers"; } - PimHostVerificationPass() {} - PimHostVerificationPass(const PimHostVerificationPass& pass) {} + PimVerificationPass() {} + PimVerificationPass(const PimVerificationPass& pass) {} void runOnOperation() override { ModuleOp moduleOp = getOperation(); @@ -132,11 +140,27 @@ private: for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { if (!isa(operand.getType())) continue; - if (succeeded(resolveContiguousAddress(operand))) - continue; - op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; - hasFailure = true; + auto resolvedAddress = resolveContiguousAddress(operand); + if (failed(resolvedAddress)) { + op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; + hasFailure = true; + continue; + } + + if (isExplicitHostOperand(&op, operandIndex)) { + if (!isCodegenAddressableValue(operand)) { + op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage"; + hasFailure = true; + } + continue; + } + + if (!isa(resolvedAddress->base.getDefiningOp())) { + op.emitOpError() << "operand #" << operandIndex + << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; + hasFailure = true; + } } } return success(!hasFailure); @@ -165,6 +189,6 @@ private: } // namespace -std::unique_ptr createPimHostVerificationPass() { return std::make_unique(); } +std::unique_ptr createPimVerificationPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 4f29e7b..6703ecf 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const { registerPass(createSpatialToPimPass); registerPass(createBufferizePimPass); registerPass(createPimConstantFoldingPass); - registerPass(createPimHostVerificationPass); + registerPass(createPimMaterializeConstantsPass); + registerPass(createPimVerificationPass); registerPass(createEmitPimJsonPass); } diff --git a/validation/operations/README.md b/validation/operations/README.md index 03c2a5c..cd6e88e 100644 --- a/validation/operations/README.md +++ b/validation/operations/README.md @@ -23,6 +23,18 @@ python3 validation/operations/gen_tests.py | With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias | | Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input | +## Pool + +| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes | +|------|-----------|-------|--------|--------|--------|---------|-------| +| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling | +| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool | +| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path | +| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling | +| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` | +| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` | +| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` | + ## Gemm | Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 812edec..526c4c9 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Generate ONNX test models for validating GEMM and Conv implementations.""" +"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations.""" import numpy as np import onnx @@ -248,6 +248,85 @@ def conv_large_spatial(): save_model(model, "conv/large_spatial", "conv_large_spatial.onnx") +# --------------------------------------------------------------------------- +# Pooling tests +# --------------------------------------------------------------------------- + +def maxpool_basic(): + """MaxPool 2x2 with stride 1.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3]) + node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "maxpool_basic", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/max_basic", "maxpool_basic.onnx") + + +def maxpool_stride2_multichannel(): + """MaxPool 2x2 with stride 2 on multiple channels.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3]) + node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx") + + +def maxpool_same_upper(): + """MaxPool 3x3 with SAME_UPPER padding.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3]) + node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER") + graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx") + + +def avgpool_basic(): + """AveragePool 2x2 with stride 1.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3]) + node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "avgpool_basic", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/avg_basic", "avgpool_basic.onnx") + + +def avgpool_explicit_padding(): + """AveragePool 3x3 with explicit padding, excluding pad from the divisor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2]) + node = helper.make_node("AveragePool", ["X"], ["Y"], + kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0) + graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx") + + +def avgpool_include_pad(): + """AveragePool 3x3 with explicit padding, including pad in the divisor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2]) + node = helper.make_node("AveragePool", ["X"], ["Y"], + kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1) + graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx") + + +def maxpool_after_conv(): + """Conv followed by MaxPool to validate pooling on lowered conv results.""" + rng = np.random.default_rng(59) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2]) + W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W") + conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0]) + graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx") + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- @@ -273,4 +352,13 @@ if __name__ == "__main__": conv_batch_2() conv_large_spatial() + print("\nGenerating Pooling tests:") + maxpool_basic() + maxpool_stride2_multichannel() + maxpool_same_upper() + avgpool_basic() + avgpool_explicit_padding() + avgpool_include_pad() + maxpool_after_conv() + print("\nDone.") diff --git a/validation/operations/pool/avg_basic/avgpool_basic.onnx b/validation/operations/pool/avg_basic/avgpool_basic.onnx new file mode 100644 index 0000000..69a1b8a Binary files /dev/null and b/validation/operations/pool/avg_basic/avgpool_basic.onnx differ diff --git a/validation/operations/pool/avg_explicit_padding/avgpool_explicit_padding.onnx b/validation/operations/pool/avg_explicit_padding/avgpool_explicit_padding.onnx new file mode 100644 index 0000000..edc1f1f Binary files /dev/null and b/validation/operations/pool/avg_explicit_padding/avgpool_explicit_padding.onnx differ diff --git a/validation/operations/pool/avg_include_pad/avgpool_include_pad.onnx b/validation/operations/pool/avg_include_pad/avgpool_include_pad.onnx new file mode 100644 index 0000000..1a1a726 Binary files /dev/null and b/validation/operations/pool/avg_include_pad/avgpool_include_pad.onnx differ diff --git a/validation/operations/pool/max_after_conv/maxpool_after_conv.onnx b/validation/operations/pool/max_after_conv/maxpool_after_conv.onnx new file mode 100644 index 0000000..da54062 Binary files /dev/null and b/validation/operations/pool/max_after_conv/maxpool_after_conv.onnx differ diff --git a/validation/operations/pool/max_basic/maxpool_basic.onnx b/validation/operations/pool/max_basic/maxpool_basic.onnx new file mode 100644 index 0000000..cd248ba Binary files /dev/null and b/validation/operations/pool/max_basic/maxpool_basic.onnx differ diff --git a/validation/operations/pool/max_same_upper/maxpool_same_upper.onnx b/validation/operations/pool/max_same_upper/maxpool_same_upper.onnx new file mode 100644 index 0000000..37c0819 Binary files /dev/null and b/validation/operations/pool/max_same_upper/maxpool_same_upper.onnx differ diff --git a/validation/operations/pool/max_stride2_multichannel/maxpool_stride2_multichannel.onnx b/validation/operations/pool/max_stride2_multichannel/maxpool_stride2_multichannel.onnx new file mode 100644 index 0000000..77745c3 Binary files /dev/null and b/validation/operations/pool/max_stride2_multichannel/maxpool_stride2_multichannel.onnx differ