reimplement pool lowering

add pool validation
align PIM ops/codegen/parser with the ISA
move constant materialization to MLIR
rename the PIM verification/materialization passes
better folded-constant handling
This commit is contained in:
NiccoloN
2026-03-23 19:14:50 +01:00
parent 461bdd808d
commit 661170a9aa
30 changed files with 912 additions and 512 deletions

View File

@@ -530,6 +530,7 @@ where
let r2_val = r2; let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported"); ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd); let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?; let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0]; let load1 = loads[0];

View File

@@ -224,7 +224,21 @@ fn json_to_vvsub(
inst_data_builder: &mut InstructionDataBuilder, inst_data_builder: &mut InstructionDataBuilder,
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
todo!("Not present in the compiler"); let json = json.as_object().expect("Not an object");
assert_eq!("vvsub", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvsub, inst_data_builder.build());
Ok(()) Ok(())
} }
@@ -256,7 +270,21 @@ fn json_to_vvdmul(
inst_data_builder: &mut InstructionDataBuilder, inst_data_builder: &mut InstructionDataBuilder,
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
todo!("Not present in the compiler"); let json = json.as_object().expect("Not an object");
assert_eq!("vvdmul", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvdmul, inst_data_builder.build());
Ok(()) Ok(())
} }
@@ -306,7 +334,21 @@ fn json_to_vavg(
inst_data_builder: &mut InstructionDataBuilder, inst_data_builder: &mut InstructionDataBuilder,
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
todo!("Not present in the compiler"); let json = json.as_object().expect("Not an object");
assert_eq!("vavg", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vavg, inst_data_builder.build());
Ok(()) Ok(())
} }
@@ -358,7 +400,7 @@ fn json_to_vsigm(
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
let json = json.as_object().expect("Not an object"); let json = json.as_object().expect("Not an object");
assert_eq!("vsigmoid", json_str!(json, "op")); assert_eq!("vsigm", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32; let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32; let rs1 = json_i64!(json, "rs1") as i32;
let len = json_i64!(json, "len") as i32; let len = json_i64!(json, "len") as i32;

View File

@@ -237,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
} }
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const { static size_t getValueSizeInBytes(mlir::Value value) {
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf()); auto type = cast<ShapedType>(value.getType());
auto aAddr = memory.getValueAddress(vaddOp.getA()); return type.getNumElements() * type.getElementTypeBitWidth() / 8;
auto bAddr = memory.getValueAddress(vaddOp.getB()); }
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType()); void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8; auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvaddOp.getA());
auto bAddr = memory.getValueAddress(vvaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json; json::Object json;
json["op"] = "vvadd"; json["op"] = "vvadd";
@@ -252,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
json["rs1"] = 1; json["rs1"] = 1;
json["rs2"] = 2; json["rs2"] = 2;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
json["len"] = totalBytes; json["len"] = getValueSizeInBytes(vvaddOp.getA());
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const { void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf()); auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
auto aAddr = memory.getValueAddress(vmaxOp.getA()); auto aAddr = memory.getValueAddress(vvsubOp.getA());
auto bAddr = memory.getValueAddress(vmaxOp.getB()); auto bAddr = memory.getValueAddress(vvsubOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvsub";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvsubOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmulOp.getA());
auto bAddr = memory.getValueAddress(vvmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json; json::Object json;
@@ -268,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
json["rs1"] = 1; json["rs1"] = 1;
json["rs2"] = 2; json["rs2"] = 2;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvdmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
auto aAddr = memory.getValueAddress(vavgOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vavg";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vavgOp.getA());
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
@@ -281,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
json["rd"] = 0; json["rd"] = 0;
json["rs1"] = 1; json["rs1"] = 1;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vreluOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
auto aAddr = memory.getValueAddress(vtanhOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vtanh";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vtanhOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
auto aAddr = memory.getValueAddress(vsigmOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vsigm";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vsigmOp.getA());
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
@@ -338,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
vaddJson["rs1"] = 1; vaddJson["rs1"] = 1;
vaddJson["rs2"] = 2; vaddJson["rs2"] = 2;
vaddJson["offset"] = createEmptyOffset(); vaddJson["offset"] = createEmptyOffset();
vaddJson["len"] = 32 * outChannels;
emitInstruction(std::move(vaddJson)); emitInstruction(std::move(vaddJson));
} }
} }
@@ -479,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp); coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op)) else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp); coreCodeGen.codeGenVVAddOp(vvaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op)) else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVMaxOp(vmaxOp); coreCodeGen.codeGenVVSubOp(vvsubOp);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op)) else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp); coreCodeGen.codeGenVReluOp(vreluOp);
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) { else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (isa<pim::PimSumOp>(op)) {
// TODO: Implement somehow? // TODO: Implement somehow?
op.emitWarning("Operation is not yet supported in code generation"); op.emitWarning("Operation is not yet supported in code generation");
continue; continue;

View File

@@ -90,9 +90,15 @@ public:
template <typename MVMTy> template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix); void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenVAddOp(pim::PimVAddOp vaddOp) const; void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const; void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
}; };

View File

@@ -47,8 +47,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass()); pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded")); pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimHostVerificationPass()); pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createMessagePass("Pim host verified")); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted")); pm.addPass(createMessagePass("Pim json code emitted"));

View File

@@ -6,7 +6,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Math/Gemm.cpp Patterns/Math/Gemm.cpp
Patterns/Math/Conv.cpp Patterns/Math/Conv.cpp
Patterns/Math/MatMul.cpp Patterns/Math/MatMul.cpp
Patterns/NN/Pooling.cpp Patterns/NN/Pool.cpp
Patterns/NN/ReduceMean.cpp Patterns/NN/ReduceMean.cpp
Patterns/Tensor/Concat.cpp Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp

View File

@@ -93,7 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
patterns.add<removeLRNPattern>(ctx); patterns.add<removeLRNPattern>(ctx);
populateConvOpPatterns(patterns, ctx); populateConvOpPatterns(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx); populatePoolTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx); populateReshapeConversionPattern(patterns, ctx);

View File

@@ -11,7 +11,7 @@ void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIR
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

View File

@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
return reduced;
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
if (divisor == 1)
return reducedWindow;
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
}
template <typename PoolOp>
struct PoolToSpatialCompute;
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
using OpConversionPattern<PoolOp>::OpConversionPattern;
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location loc = poolOp.getLoc();
Value x = adaptor.getX();
auto xType = dyn_cast<RankedTensorType>(x.getType());
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
if (xType.getRank() != 4 || outType.getRank() != 4)
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
ArrayAttr kernelAttr = poolOp.getKernelShape();
if (!kernelAttr || kernelAttr.size() != 2)
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
const int64_t batchSize = xType.getDimSize(0);
const int64_t channels = xType.getDimSize(1);
const int64_t inputHeight = xType.getDimSize(2);
const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
int64_t padTop = 0;
int64_t padLeft = 0;
int64_t padBottom = 0;
int64_t padRight = 0;
if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2);
padRight = getI64(*padsAttr, 3);
}
else {
StringRef autoPad = poolOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
if (autoPad == "SAME_UPPER") {
padTop = totalPadH / 2;
padBottom = totalPadH - padTop;
padLeft = totalPadW / 2;
padRight = totalPadW - padLeft;
}
else {
padBottom = totalPadH / 2;
padTop = totalPadH - padBottom;
padRight = totalPadW / 2;
padLeft = totalPadW - padRight;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
}
}
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
auto* computeBlock = new Block();
computeBlock->addArgument(xType, loc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
Value input = computeBlock->getArgument(0);
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.replaceOp(poolOp, computeOp.getResult(0));
return success();
}
};
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
} // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,427 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
// Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1)
return valuesToReduce[0];
if (preprocess) {
for (auto& valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce);
}
}
// It is possible that `valuesToReduce` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& valToReduce : valuesToReduce) {
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
// } else {
// // Otherwise it is a block argument,
// computeOp->getBlock()->getParentOp();
// }
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
auto it = lastValueForCompute.find(computeOp);
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
else
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce;
}
lastValueForCompute[computeOp] = valToReduce;
}
// Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute)
valuesToReduce.push_back(entry.second);
Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the valuesToReduce list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
while (valuesToReduceRef.size() > 1) {
SmallVector<Value> nextValuesToReduce;
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
auto firstValue = valuesToReduceRef[i];
auto secondValue = valuesToReduceRef[i + 1];
auto firstCompute = firstValue.getParentBlock()->getParentOp();
auto secondCompute = secondValue.getParentBlock()->getParentOp();
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstValue, secondValue);
std::swap(firstCompute, secondCompute);
}
// 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute);
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue);
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
// 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue);
Value reduced = reduce(receivedValue, secondValue);
nextValuesToReduce.push_back(reduced);
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (valuesToReduceRef.size() % 2 == 1)
nextValuesToReduce.push_back(valuesToReduceRef.back());
// Replace the inputOps list with the new one.
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
}
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0];
if (postprocess) {
rewriter.setInsertionPointAfterValue(finalValue);
finalValue = postprocess(finalValue);
}
return finalValue;
}
template <typename PoolOp>
bool hasPostProcessPoolingWindow() {
return false;
}
template <>
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = getImageHeight(xShape);
size_t input_w = getImageWidth(xShape);
size_t output_h = getImageHeight(yShape);
size_t output_w = getImageWidth(yShape);
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
// 1: Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt =
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
// TODO: This requires a core for each input tile, which is not ideal. We
// can do better.
// If some input tiles come from the func.func operands, load
// them into a computeOp and yield them
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
tileLoc,
extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(),
extractSliceOp.getResult());
Block* tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock);
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0);
}
}
}
}
// 2: Tile the output tensor
// Output tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel
SmallVector<Value> valuesToPool;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
// Each output pixel tile is computed by pooling a window of input
// pixel tiles
valuesToPool.clear();
size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++;
continue;
}
Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool;
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber);
}
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
}
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData();
}
else {
return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
}
valuesToPool.push_back(valueToPool);
}
}
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???");
std::function<Value(const Value&)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
};
}
Value reducedWithinCompute = applyReducePatternNew(
valuesToPool,
rewriter,
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
nullptr,
postProcessFn);
// Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the
// users computeOps
auto computeOpOfReduced =
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel =
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue;
}
}
}
// TODO: outputTiles are not the results of the computeOps! We need to add
// them!
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
auto outputTile = outputTiles[outTile][outX][outY];
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
}
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
}
}
}
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage);
return success();
}
};
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimVAddOp : Pat< def spatToPimVVAddOp : Pat<
(SpatVAddOp:$srcOpRes $a, $b), (SpatVAddOp:$srcOpRes $a, $b),
(PimVAddOp $a, $b, (PimVVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMulOp : Pat<
(SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMaxOp : Pat<
(SpatVMaxOp:$srcOpRes $a, $b),
(PimVVMaxOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;

View File

@@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> { def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
let description = [{ let description = [{
Element-wise addition: c = a + b Element-wise addition: c = a + b
}]; }];
@@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
let description = [{
Element-wise subtraction: c = a - b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
let description = [{
Element-wise multiplication: c = a * b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
let description = [{ let description = [{
Element-wise max: c = max(a, b) Element-wise max: c = max(a, b)
}]; }];
@@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
let results = (outs let results = (outs
PimTensor: $outRes PimTensor: $outRes
); );
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Dot product: c = dot(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
} }
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
@@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
); );
} }
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{ let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience) Average all elements into a single one
}]; }];
let arguments = (ins let arguments = (ins
PimTensor: $dividend, PimTensor: $a,
PimTensor: $divisor,
PimTensor: $outBuf PimTensor: $outBuf
); );
@@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
); );
} }
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{ let description = [{
Element-wise exp: c = exp(a) Element-wise tanh activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise sigmoid activation
}]; }];
let arguments = (ins let arguments = (ins

View File

@@ -30,12 +30,13 @@ void PimDialect::initialize() {
registerDependenciesFn(this->getOutBuf(), this->getResult()); \ registerDependenciesFn(this->getOutBuf(), this->getResult()); \
} }
POPULATE_DEPENDENCIES(PimVMaxOp) POPULATE_DEPENDENCIES(PimVVDMulOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp) POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp) POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp) POPULATE_DEPENDENCIES(PimVAvgOp)
POPULATE_DEPENDENCIES(PimVReluOp) POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp) POPULATE_DEPENDENCIES(PimVTanhOp)
POPULATE_DEPENDENCIES(PimVSigmOp)
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp" #include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -12,6 +13,26 @@ using namespace bufferization;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
struct MemCopyHostToDevOpInterface struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
} }
}; };
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> { template <typename OpTy>
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
@@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
RewriterBase& rewriter, RewriterBase& rewriter,
const BufferizationOptions& options, const BufferizationOptions& options,
BufferizationState& state) const { BufferizationState& state) const {
auto vaddOp = cast<PimVAddOp>(op); auto binaryOp = cast<OpTy>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state); auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
if (failed(aOpt)) if (failed(aOpt))
return failure(); return failure();
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state); auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
if (failed(bOpt)) if (failed(bOpt))
return failure(); return failure();
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state); auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
if (failed(outBufOpt)) if (failed(outBufOpt))
return failure(); return failure();
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt); Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
return success(); return success();
} }
}; };
@@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx); PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx); PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx); PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
}); });
} }

View File

@@ -17,6 +17,7 @@
#include <cstdint> #include <cstdint>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
@@ -36,6 +37,25 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
return memref::AllocOp::create(rewriter, loc, memrefResultType); return memref::AllocOp::create(rewriter, loc, memrefResultType);
} }
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return pim::PimMemCopyOp::create(rewriter,
loc,
contiguousBuffer.getType(),
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) { static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
@@ -167,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
auto memref = getBuffer(rewriter, operand, options, state); auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref)) if (failed(memref))
return failure(); return failure();
memrefOperands.push_back(*memref); memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
} }
// TODO: Support addiction with more than 2 operands // TODO: Support addiction with more than 2 operands
@@ -460,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
}; };
struct VAddOpInterfaceFromTemplate struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {}; : VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {}; struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
@@ -468,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {}; struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {}; struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation. // Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> { struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
@@ -557,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx); SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx); SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx); SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx); SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx); SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx); SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
@@ -569,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {}; struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {}; struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
struct ONNXSigmoidInterface
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx); ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx); ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
}); });
} }

View File

@@ -5,7 +5,8 @@ add_pim_library(OMPimPasses
PimConstantFolding/Patterns/Constant.cpp PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp PimConstantFolding/Patterns/Subview.cpp
PimHostVerificationPass.cpp PimMaterializeConstantsPass.cpp
PimVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -120,20 +120,8 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(coreOp); rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
if (elementByteWidth == 0)
return failure();
size_t totalBytes = initType.getNumElements() * elementByteWidth;
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
pim::PimMemCopyHostToDevOp::create(rewriter, rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
mapOp.getLoc(),
initType,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
rewriter.eraseOp(mapOp); rewriter.eraseOp(mapOp);
return success(); return success();
} }

View File

@@ -0,0 +1,135 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
static int64_t getValueSizeInBytes(Value value) {
auto type = dyn_cast<ShapedType>(value.getType());
if (!type || !type.hasStaticShape())
return -1;
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct PimMaterializeConstantsPass
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
OpBuilder rewriter(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
if (isa<pim::PimHaltOp>(op))
continue;
for (OpOperand& operand : op.getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op.emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end()) {
operand.set(cachedValue->second);
continue;
}
int64_t totalBytes = getValueSizeInBytes(originalValue);
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
hasFailure = true;
continue;
}
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(&op);
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult());
}
}
}
}
if (hasFailure) {
signalPassFailure();
return;
}
dumpModule(moduleOp, "pim3_materialized");
}
};
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
return std::make_unique<PimMaterializeConstantsPass>();
}
} // namespace onnx_mlir

View File

@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass(); std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostVerificationPass(); std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass(); std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -35,16 +35,24 @@ static bool isCodegenAddressableValue(Value value) {
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp()); || isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
} }
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass) if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
StringRef getArgument() const override { return "verify-pim-host-pass"; } struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override { StringRef getDescription() const override {
return "Verify that no runtime host-side code remains in bufferized PIM IR"; return "Verify that bufferized PIM IR contains only explicit host/device transfers";
} }
PimHostVerificationPass() {} PimVerificationPass() {}
PimHostVerificationPass(const PimHostVerificationPass& pass) {} PimVerificationPass(const PimVerificationPass& pass) {}
void runOnOperation() override { void runOnOperation() override {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -132,11 +140,27 @@ private:
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType())) if (!isa<BaseMemRefType>(operand.getType()))
continue; continue;
if (succeeded(resolveContiguousAddress(operand)))
continue;
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; auto resolvedAddress = resolveContiguousAddress(operand);
hasFailure = true; if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
} }
} }
return success(!hasFailure); return success(!hasFailure);
@@ -165,6 +189,6 @@ private:
} // namespace } // namespace
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); } std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass); registerPass(createBufferizePimPass);
registerPass(createPimConstantFoldingPass); registerPass(createPimConstantFoldingPass);
registerPass(createPimHostVerificationPass); registerPass(createPimMaterializeConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }

View File

@@ -23,6 +23,18 @@ python3 validation/operations/gen_tests.py
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias | | With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input | | Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
## Pool
| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes |
|------|-----------|-------|--------|--------|--------|---------|-------|
| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling |
| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool |
| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path |
| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling |
| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` |
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
## Gemm ## Gemm
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | | Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Generate ONNX test models for validating GEMM and Conv implementations.""" """Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
import numpy as np import numpy as np
import onnx import onnx
@@ -248,6 +248,85 @@ def conv_large_spatial():
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx") save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
# ---------------------------------------------------------------------------
# Pooling tests
# ---------------------------------------------------------------------------
def maxpool_basic():
"""MaxPool 2x2 with stride 1."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "maxpool_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_basic", "maxpool_basic.onnx")
def maxpool_stride2_multichannel():
"""MaxPool 2x2 with stride 2 on multiple channels."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx")
def maxpool_same_upper():
"""MaxPool 3x3 with SAME_UPPER padding."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx")
def avgpool_basic():
"""AveragePool 2x2 with stride 1."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3])
node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "avgpool_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_basic", "avgpool_basic.onnx")
def avgpool_explicit_padding():
"""AveragePool 3x3 with explicit padding, excluding pad from the divisor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node("AveragePool", ["X"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0)
graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx")
def avgpool_include_pad():
"""AveragePool 3x3 with explicit padding, including pad in the divisor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node("AveragePool", ["X"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1)
graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx")
def maxpool_after_conv():
"""Conv followed by MaxPool to validate pooling on lowered conv results."""
rng = np.random.default_rng(59)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2])
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Main # Main
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -273,4 +352,13 @@ if __name__ == "__main__":
conv_batch_2() conv_batch_2()
conv_large_spatial() conv_large_spatial()
print("\nGenerating Pooling tests:")
maxpool_basic()
maxpool_stride2_multichannel()
maxpool_same_upper()
avgpool_basic()
avgpool_explicit_padding()
avgpool_include_pad()
maxpool_after_conv()
print("\nDone.") print("\nDone.")