reimplement pool lowering
add pool validation align PIM ops/codegen/parser with the ISA move constant materialization to MLIR rename the PIM verification/materialization passes better folded-constant handling
This commit is contained in:
@@ -530,6 +530,7 @@ where
|
|||||||
let r2_val = r2;
|
let r2_val = r2;
|
||||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||||
let rd_val = core.register(rd);
|
let rd_val = core.register(rd);
|
||||||
|
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
||||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||||
let load1 = loads[0];
|
let load1 = loads[0];
|
||||||
|
|||||||
@@ -224,7 +224,21 @@ fn json_to_vvsub(
|
|||||||
inst_data_builder: &mut InstructionDataBuilder,
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
todo!("Not present in the compiler");
|
let json = json.as_object().expect("Not an object");
|
||||||
|
assert_eq!("vvsub", json_str!(json, "op"));
|
||||||
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
|
let rs2 = json_i64!(json, "rs2") as i32;
|
||||||
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd(rd)
|
||||||
|
.set_r1(rs1)
|
||||||
|
.set_r2(rs2)
|
||||||
|
.set_imm_len(len)
|
||||||
|
.set_offset_select(offset_select)
|
||||||
|
.set_offset_value(offset_value);
|
||||||
|
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,7 +270,21 @@ fn json_to_vvdmul(
|
|||||||
inst_data_builder: &mut InstructionDataBuilder,
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
todo!("Not present in the compiler");
|
let json = json.as_object().expect("Not an object");
|
||||||
|
assert_eq!("vvdmul", json_str!(json, "op"));
|
||||||
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
|
let rs2 = json_i64!(json, "rs2") as i32;
|
||||||
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd(rd)
|
||||||
|
.set_r1(rs1)
|
||||||
|
.set_r2(rs2)
|
||||||
|
.set_imm_len(len)
|
||||||
|
.set_offset_select(offset_select)
|
||||||
|
.set_offset_value(offset_value);
|
||||||
|
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,7 +334,21 @@ fn json_to_vavg(
|
|||||||
inst_data_builder: &mut InstructionDataBuilder,
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
todo!("Not present in the compiler");
|
let json = json.as_object().expect("Not an object");
|
||||||
|
assert_eq!("vavg", json_str!(json, "op"));
|
||||||
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
|
let rs2 = json_i64!(json, "rs2") as i32;
|
||||||
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd(rd)
|
||||||
|
.set_r1(rs1)
|
||||||
|
.set_r2(rs2)
|
||||||
|
.set_imm_len(len)
|
||||||
|
.set_offset_select(offset_select)
|
||||||
|
.set_offset_value(offset_value);
|
||||||
|
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -358,7 +400,7 @@ fn json_to_vsigm(
|
|||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let json = json.as_object().expect("Not an object");
|
let json = json.as_object().expect("Not an object");
|
||||||
assert_eq!("vsigmoid", json_str!(json, "op"));
|
assert_eq!("vsigm", json_str!(json, "op"));
|
||||||
let rd = json_i64!(json, "rd") as i32;
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
let rs1 = json_i64!(json, "rs1") as i32;
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
let len = json_i64!(json, "len") as i32;
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
|||||||
@@ -237,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
|
|||||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||||
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
auto aAddr = memory.getValueAddress(vaddOp.getA());
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
auto bAddr = memory.getValueAddress(vaddOp.getB());
|
}
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
|
||||||
|
|
||||||
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
|
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||||
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
|
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vvaddOp.getA());
|
||||||
|
auto bAddr = memory.getValueAddress(vvaddOp.getB());
|
||||||
|
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvadd";
|
json["op"] = "vvadd";
|
||||||
@@ -252,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = totalBytes;
|
json["len"] = getValueSizeInBytes(vvaddOp.getA());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf());
|
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
|
||||||
auto aAddr = memory.getValueAddress(vmaxOp.getA());
|
auto aAddr = memory.getValueAddress(vvsubOp.getA());
|
||||||
auto bAddr = memory.getValueAddress(vmaxOp.getB());
|
auto bAddr = memory.getValueAddress(vvsubOp.getB());
|
||||||
|
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vvsub";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["rs2"] = 2;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvsubOp.getA());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||||
|
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vvmulOp.getA());
|
||||||
|
auto bAddr = memory.getValueAddress(vvmulOp.getB());
|
||||||
|
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vvmul";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["rs2"] = 2;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvmulOp.getA());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||||
|
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
|
||||||
|
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
@@ -268,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||||
|
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
|
||||||
|
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
|
||||||
|
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vvdmul";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["rs2"] = 2;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||||
|
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vavgOp.getA());
|
||||||
|
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vavg";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vavgOp.getA());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
|||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vreluOp.getA());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||||
|
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vtanhOp.getA());
|
||||||
|
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vtanh";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vtanhOp.getA());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||||
|
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
|
||||||
|
auto aAddr = memory.getValueAddress(vsigmOp.getA());
|
||||||
|
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vsigm";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vsigmOp.getA());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
|
|||||||
vaddJson["rs1"] = 1;
|
vaddJson["rs1"] = 1;
|
||||||
vaddJson["rs2"] = 2;
|
vaddJson["rs2"] = 2;
|
||||||
vaddJson["offset"] = createEmptyOffset();
|
vaddJson["offset"] = createEmptyOffset();
|
||||||
|
vaddJson["len"] = 32 * outChannels;
|
||||||
emitInstruction(std::move(vaddJson));
|
emitInstruction(std::move(vaddJson));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -479,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
||||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||||
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||||
coreCodeGen.codeGenVAddOp(vaddOp);
|
coreCodeGen.codeGenVVAddOp(vvaddOp);
|
||||||
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
|
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
||||||
coreCodeGen.codeGenVMaxOp(vmaxOp);
|
coreCodeGen.codeGenVVSubOp(vvsubOp);
|
||||||
|
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
||||||
|
coreCodeGen.codeGenVVMulOp(vvmulOp);
|
||||||
|
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
||||||
|
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
|
||||||
|
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
||||||
|
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
|
||||||
|
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
||||||
|
coreCodeGen.codeGenVAvgOp(vavgOp);
|
||||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||||
coreCodeGen.codeGenVReluOp(vreluOp);
|
coreCodeGen.codeGenVReluOp(vreluOp);
|
||||||
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
|
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
||||||
|
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||||
|
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||||
|
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||||
|
else if (isa<pim::PimSumOp>(op)) {
|
||||||
// TODO: Implement somehow?
|
// TODO: Implement somehow?
|
||||||
op.emitWarning("Operation is not yet supported in code generation");
|
op.emitWarning("Operation is not yet supported in code generation");
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
@@ -90,9 +90,15 @@ public:
|
|||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
||||||
|
|
||||||
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
|
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
|
||||||
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
|
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
|
||||||
|
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
|
||||||
|
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
|
||||||
|
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
|
||||||
|
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
|
||||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||||
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||||
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
||||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -47,8 +47,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createPimConstantFoldingPass());
|
pm.addPass(createPimConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim constants folded"));
|
pm.addPass(createMessagePass("Pim constants folded"));
|
||||||
pm.addPass(createPimHostVerificationPass());
|
pm.addPass(createPimMaterializeConstantsPass());
|
||||||
pm.addPass(createMessagePass("Pim host verified"));
|
pm.addPass(createPimVerificationPass());
|
||||||
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Math/Gemm.cpp
|
Patterns/Math/Gemm.cpp
|
||||||
Patterns/Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
Patterns/Math/MatMul.cpp
|
Patterns/Math/MatMul.cpp
|
||||||
Patterns/NN/Pooling.cpp
|
Patterns/NN/Pool.cpp
|
||||||
Patterns/NN/ReduceMean.cpp
|
Patterns/NN/ReduceMean.cpp
|
||||||
Patterns/Tensor/Concat.cpp
|
Patterns/Tensor/Concat.cpp
|
||||||
Patterns/Tensor/Reshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
patterns.add<removeLRNPattern>(ctx);
|
patterns.add<removeLRNPattern>(ctx);
|
||||||
|
|
||||||
populateConvOpPatterns(patterns, ctx);
|
populateConvOpPatterns(patterns, ctx);
|
||||||
populatePoolingTilingPattern(patterns, ctx);
|
populatePoolTilingPattern(patterns, ctx);
|
||||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||||
populateReshapeConversionPattern(patterns, ctx);
|
populateReshapeConversionPattern(patterns, ctx);
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIR
|
|||||||
|
|
||||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
|||||||
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <optional>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename ArrayAttrT>
|
||||||
|
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||||
|
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ArrayAttrT>
|
||||||
|
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||||
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||||
|
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||||
|
if (values.size() == 1)
|
||||||
|
return values.front();
|
||||||
|
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
|
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||||
|
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.reserve(tileType.getRank());
|
||||||
|
for (int64_t dimSize : tileType.getShape())
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||||
|
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
|
||||||
|
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ReduceOp>
|
||||||
|
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||||
|
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||||
|
|
||||||
|
Value reduced = windowValues.front();
|
||||||
|
for (Value value : windowValues.drop_front())
|
||||||
|
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
||||||
|
return reduced;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||||
|
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||||
|
if (divisor == 1)
|
||||||
|
return reducedWindow;
|
||||||
|
|
||||||
|
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
||||||
|
double scale = 1.0 / static_cast<double>(divisor);
|
||||||
|
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||||
|
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||||
|
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PoolOp>
|
||||||
|
struct PoolToSpatialCompute;
|
||||||
|
|
||||||
|
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||||
|
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||||
|
using OpConversionPattern<PoolOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Location loc = poolOp.getLoc();
|
||||||
|
Value x = adaptor.getX();
|
||||||
|
|
||||||
|
auto xType = dyn_cast<RankedTensorType>(x.getType());
|
||||||
|
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
|
||||||
|
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
|
||||||
|
if (xType.getRank() != 4 || outType.getRank() != 4)
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
|
||||||
|
|
||||||
|
ArrayAttr kernelAttr = poolOp.getKernelShape();
|
||||||
|
if (!kernelAttr || kernelAttr.size() != 2)
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
|
const int64_t channels = xType.getDimSize(1);
|
||||||
|
const int64_t inputHeight = xType.getDimSize(2);
|
||||||
|
const int64_t inputWidth = xType.getDimSize(3);
|
||||||
|
const int64_t outputHeight = outType.getDimSize(2);
|
||||||
|
const int64_t outputWidth = outType.getDimSize(3);
|
||||||
|
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||||
|
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||||
|
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||||
|
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||||
|
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||||
|
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||||
|
|
||||||
|
int64_t padTop = 0;
|
||||||
|
int64_t padLeft = 0;
|
||||||
|
int64_t padBottom = 0;
|
||||||
|
int64_t padRight = 0;
|
||||||
|
|
||||||
|
if (auto padsAttr = poolOp.getPads()) {
|
||||||
|
if (padsAttr->size() != 4)
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||||
|
padTop = getI64(*padsAttr, 0);
|
||||||
|
padLeft = getI64(*padsAttr, 1);
|
||||||
|
padBottom = getI64(*padsAttr, 2);
|
||||||
|
padRight = getI64(*padsAttr, 3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
StringRef autoPad = poolOp.getAutoPad();
|
||||||
|
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
|
||||||
|
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
|
||||||
|
const int64_t totalPadH =
|
||||||
|
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
|
||||||
|
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
|
||||||
|
|
||||||
|
if (autoPad == "SAME_UPPER") {
|
||||||
|
padTop = totalPadH / 2;
|
||||||
|
padBottom = totalPadH - padTop;
|
||||||
|
padLeft = totalPadW / 2;
|
||||||
|
padRight = totalPadW - padLeft;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
padBottom = totalPadH / 2;
|
||||||
|
padTop = totalPadH - padBottom;
|
||||||
|
padRight = totalPadW / 2;
|
||||||
|
padLeft = totalPadW - padRight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(void) padBottom;
|
||||||
|
(void) padRight;
|
||||||
|
|
||||||
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||||
|
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
|
||||||
|
|
||||||
|
auto* computeBlock = new Block();
|
||||||
|
computeBlock->addArgument(xType, loc);
|
||||||
|
computeOp.getBody().push_back(computeBlock);
|
||||||
|
rewriter.setInsertionPointToStart(computeBlock);
|
||||||
|
|
||||||
|
Value input = computeBlock->getArgument(0);
|
||||||
|
SmallVector<Value> batchResults;
|
||||||
|
batchResults.reserve(batchSize);
|
||||||
|
|
||||||
|
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||||
|
SmallVector<Value> rows;
|
||||||
|
rows.reserve(outputHeight);
|
||||||
|
|
||||||
|
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||||
|
SmallVector<Value> rowPixels;
|
||||||
|
rowPixels.reserve(outputWidth);
|
||||||
|
|
||||||
|
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||||
|
SmallVector<Value> outputChannelTiles;
|
||||||
|
outputChannelTiles.reserve(channelTileCount);
|
||||||
|
|
||||||
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||||
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||||
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||||
|
|
||||||
|
SmallVector<Value> windowValues;
|
||||||
|
windowValues.reserve(kernelHeight * kernelWidth);
|
||||||
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
|
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||||
|
if (inH < 0 || inH >= inputHeight)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
|
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||||
|
if (inW < 0 || inW >= inputWidth)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||||
|
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||||
|
rewriter.getIndexAttr(inH),
|
||||||
|
rewriter.getIndexAttr(inW)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(tileChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
Value windowValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
|
||||||
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||||
|
windowValues.push_back(windowValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (windowValues.empty())
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||||
|
|
||||||
|
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
|
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||||
|
const int64_t divisor =
|
||||||
|
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||||
|
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||||
|
}
|
||||||
|
|
||||||
|
outputChannelTiles.push_back(reducedWindow);
|
||||||
|
}
|
||||||
|
|
||||||
|
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||||
|
}
|
||||||
|
|
||||||
|
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||||
|
|
||||||
|
rewriter.replaceOp(poolOp, computeOp.getResult(0));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
|
||||||
|
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
|
||||||
|
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
||||||
|
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
|
||||||
|
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||||
|
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,427 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstddef>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
std::function<Value(const Value&, const Value&)> reduce,
|
|
||||||
std::function<Value(const Value&)> preprocess,
|
|
||||||
std::function<Value(const Value&)> postprocess) {
|
|
||||||
// Simple case: if we have only one input, just return it
|
|
||||||
if (valuesToReduce.size() == 1)
|
|
||||||
return valuesToReduce[0];
|
|
||||||
|
|
||||||
if (preprocess) {
|
|
||||||
for (auto& valToReduce : valuesToReduce) {
|
|
||||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
|
||||||
valToReduce = preprocess(valToReduce);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// It is possible that `valuesToReduce` contains two entries for the same
|
|
||||||
// computeOp. In this case, we need to apply the reduction within-computef
|
|
||||||
|
|
||||||
// Keep a map between a computeOp and the last Value for this reduction
|
|
||||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
|
||||||
for (auto& valToReduce : valuesToReduce) {
|
|
||||||
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
|
|
||||||
// if (valToReduce.getDefiningOp()) {
|
|
||||||
// // If the value is defined by an operation, we take the parent
|
|
||||||
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
|
|
||||||
// } else {
|
|
||||||
// // Otherwise it is a block argument,
|
|
||||||
// computeOp->getBlock()->getParentOp();
|
|
||||||
// }
|
|
||||||
|
|
||||||
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
|
|
||||||
|
|
||||||
auto it = lastValueForCompute.find(computeOp);
|
|
||||||
|
|
||||||
if (it != lastValueForCompute.end()) {
|
|
||||||
// If we have already seen this computeOp, apply the reduction
|
|
||||||
// within-compute
|
|
||||||
Value lastWithinComputeValue = it->second;
|
|
||||||
|
|
||||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
|
||||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
|
||||||
else
|
|
||||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
|
||||||
valToReduce = reduce(lastWithinComputeValue, valToReduce);
|
|
||||||
lastValueForCompute[computeOp] = valToReduce;
|
|
||||||
}
|
|
||||||
|
|
||||||
lastValueForCompute[computeOp] = valToReduce;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, reconstruct from the map the valuesToReduce list
|
|
||||||
valuesToReduce.clear();
|
|
||||||
valuesToReduce.reserve(lastValueForCompute.size());
|
|
||||||
for (auto& entry : lastValueForCompute)
|
|
||||||
valuesToReduce.push_back(entry.second);
|
|
||||||
|
|
||||||
Location loc = valuesToReduce[0].getLoc();
|
|
||||||
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
|
|
||||||
|
|
||||||
// Recursive algorithm to reduce the inputs to a single one:
|
|
||||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
|
||||||
// the valuesToReduce list which becomes half the size.
|
|
||||||
// - Repeat until there is only one input left.
|
|
||||||
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
|
|
||||||
while (valuesToReduceRef.size() > 1) {
|
|
||||||
SmallVector<Value> nextValuesToReduce;
|
|
||||||
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
|
|
||||||
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
|
|
||||||
auto firstValue = valuesToReduceRef[i];
|
|
||||||
auto secondValue = valuesToReduceRef[i + 1];
|
|
||||||
|
|
||||||
auto firstCompute = firstValue.getParentBlock()->getParentOp();
|
|
||||||
auto secondCompute = secondValue.getParentBlock()->getParentOp();
|
|
||||||
|
|
||||||
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
|
|
||||||
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
|
|
||||||
|
|
||||||
if (secondCompute->isBeforeInBlock(firstCompute)) {
|
|
||||||
std::swap(firstValue, secondValue);
|
|
||||||
std::swap(firstCompute, secondCompute);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. Add a channel before the first computeOp
|
|
||||||
rewriter.setInsertionPoint(firstCompute);
|
|
||||||
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
|
||||||
|
|
||||||
// 2. Add a sendOp after the first value
|
|
||||||
rewriter.setInsertionPointAfterValue(firstValue);
|
|
||||||
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
|
|
||||||
|
|
||||||
// 3. Add a receiveOp after the second value
|
|
||||||
rewriter.setInsertionPointAfterValue(secondValue);
|
|
||||||
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
|
|
||||||
|
|
||||||
// 4. Apply reduction between second value and received value
|
|
||||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
|
||||||
Value reduced = reduce(receivedValue, secondValue);
|
|
||||||
|
|
||||||
nextValuesToReduce.push_back(reduced);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have an odd number of inputs, we need to add the last one to the
|
|
||||||
// newInputs list.
|
|
||||||
if (valuesToReduceRef.size() % 2 == 1)
|
|
||||||
nextValuesToReduce.push_back(valuesToReduceRef.back());
|
|
||||||
|
|
||||||
// Replace the inputOps list with the new one.
|
|
||||||
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
|
|
||||||
|
|
||||||
auto finalValue = valuesToReduceRef[0];
|
|
||||||
|
|
||||||
if (postprocess) {
|
|
||||||
rewriter.setInsertionPointAfterValue(finalValue);
|
|
||||||
finalValue = postprocess(finalValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
return finalValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp>
|
|
||||||
bool hasPostProcessPoolingWindow() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp>
|
|
||||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
PoolOp poolOp,
|
|
||||||
Value valueToDivide,
|
|
||||||
size_t krn_size,
|
|
||||||
size_t tilesSkippedByPadding) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
ONNXAveragePoolOp poolOp,
|
|
||||||
Value valueToDivide,
|
|
||||||
size_t krn_size,
|
|
||||||
size_t tilesSkippedByPadding) {
|
|
||||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
|
||||||
|
|
||||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
|
||||||
|
|
||||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
|
||||||
|
|
||||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
|
||||||
// compatible with the current code generation, which assumes constant to be
|
|
||||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
|
||||||
// directly under func.func (i.e. alongside ComputeOps)
|
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
|
||||||
rewriter.setInsertionPoint(computeOp);
|
|
||||||
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
scalarTensor,
|
|
||||||
rewriter.getI64IntegerAttr(divisorNumber),
|
|
||||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
|
||||||
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
|
||||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|
||||||
PoolingBaseConverter(MLIRContext* ctx)
|
|
||||||
: OpConversionPattern<PoolOp>(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
|
||||||
Value X = adaptor.getX();
|
|
||||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
|
||||||
Value Y = poolOp.getResult();
|
|
||||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
|
||||||
|
|
||||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
|
||||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
|
||||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
|
||||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
|
||||||
|
|
||||||
if (adaptor.getAutoPad() != "NOTSET")
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
|
||||||
|
|
||||||
size_t pad_x, pad_y;
|
|
||||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
|
||||||
if (padUnpackError.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
|
||||||
|
|
||||||
Location loc = poolOp.getLoc();
|
|
||||||
|
|
||||||
size_t input_h = getImageHeight(xShape);
|
|
||||||
size_t input_w = getImageWidth(xShape);
|
|
||||||
size_t output_h = getImageHeight(yShape);
|
|
||||||
size_t output_w = getImageWidth(yShape);
|
|
||||||
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
|
||||||
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
|
|
||||||
|
|
||||||
// 1: Tile the input tensor
|
|
||||||
// Input tiles need to be indexed by:
|
|
||||||
// a. Channel Tile
|
|
||||||
// b. Pixel `x` position
|
|
||||||
// c. Pixel `y` position
|
|
||||||
// For example: inputTiles[channelTile][x][y]
|
|
||||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
|
||||||
// Suppose that the input tensor is produced by concatenating the results of
|
|
||||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
|
||||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
|
||||||
|
|
||||||
auto resolveErrorOpt =
|
|
||||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
if (resolveErrorOpt.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
|
||||||
|
|
||||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
|
||||||
// can do better.
|
|
||||||
// If some input tiles come from the func.func operands, load
|
|
||||||
// them into a computeOp and yield them
|
|
||||||
for (size_t t = 0; t < channelTileCount; t++) {
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
|
||||||
Location tileLoc = extractSliceOp.getLoc();
|
|
||||||
|
|
||||||
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
|
||||||
tileLoc,
|
|
||||||
extractSliceOp.getResultType(),
|
|
||||||
/* xbarWeights =*/ValueRange(),
|
|
||||||
extractSliceOp.getResult());
|
|
||||||
|
|
||||||
Block* tempComputeOpBlock = new Block();
|
|
||||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
|
||||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
|
||||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
|
||||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2: Tile the output tensor
|
|
||||||
// Output tiles need to be indexed by:
|
|
||||||
// a. Channel Tile
|
|
||||||
// b. Pixel `x` position
|
|
||||||
// c. Pixel `y` position
|
|
||||||
// For example: outputTiles[channelTile][x][y]
|
|
||||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
|
||||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
|
||||||
|
|
||||||
// List of values to pool for each output pixel
|
|
||||||
SmallVector<Value> valuesToPool;
|
|
||||||
|
|
||||||
// Iterate each output tile
|
|
||||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
|
||||||
// Iterate each output pixel
|
|
||||||
for (size_t outX = 0; outX < output_w; outX++) {
|
|
||||||
for (size_t outY = 0; outY < output_h; outY++) {
|
|
||||||
|
|
||||||
// Each output pixel tile is computed by pooling a window of input
|
|
||||||
// pixel tiles
|
|
||||||
valuesToPool.clear();
|
|
||||||
size_t tilesSkippedByPadding = 0;
|
|
||||||
|
|
||||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
|
||||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
|
||||||
|
|
||||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
|
||||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
|
||||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
|
||||||
tilesSkippedByPadding++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Value inputTile = inputTiles[outTile][inX][inY];
|
|
||||||
|
|
||||||
Value valueToPool;
|
|
||||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
|
||||||
|
|
||||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
|
||||||
|
|
||||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
|
||||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
|
||||||
}
|
|
||||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
|
||||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
|
||||||
if (failed(sendOpOpt)) {
|
|
||||||
return rewriter.notifyMatchFailure(poolOp,
|
|
||||||
"ChannelReceiveOp does not have a matching "
|
|
||||||
"ChannelSendOp.");
|
|
||||||
}
|
|
||||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
|
||||||
|
|
||||||
valueToPool = sendOp.getData();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return rewriter.notifyMatchFailure(poolOp,
|
|
||||||
"Input tile for Pooling is not produced by a "
|
|
||||||
"WeightedComputeOp nor a receiveOp");
|
|
||||||
}
|
|
||||||
|
|
||||||
valuesToPool.push_back(valueToPool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
|
||||||
// assert(computeOpsForPooling.size() != 1 &&
|
|
||||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
|
||||||
// this " "should have been simplified earlier???");
|
|
||||||
|
|
||||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
|
||||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
|
||||||
postProcessFn = [&](const Value prevFinalRes) {
|
|
||||||
return postProcessPoolingWindow(
|
|
||||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
Value reducedWithinCompute = applyReducePatternNew(
|
|
||||||
valuesToPool,
|
|
||||||
rewriter,
|
|
||||||
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
|
||||||
nullptr,
|
|
||||||
postProcessFn);
|
|
||||||
|
|
||||||
// Send this value through a channel, and receive it in the
|
|
||||||
// `func.func`. During lowering, we will need to "move it" into the
|
|
||||||
// users computeOps
|
|
||||||
auto computeOpOfReduced =
|
|
||||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
|
||||||
|
|
||||||
// Create a new channel before the computeOp
|
|
||||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
|
||||||
auto reduceChannel =
|
|
||||||
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
|
||||||
|
|
||||||
// Send value through the channel
|
|
||||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
|
||||||
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
|
||||||
|
|
||||||
// Receive after the computeOp
|
|
||||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
|
||||||
auto receivedValue =
|
|
||||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
|
||||||
|
|
||||||
outputTiles[outTile][outX][outY] = receivedValue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
|
||||||
// them!
|
|
||||||
|
|
||||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
|
||||||
|
|
||||||
// Iterate each output tile
|
|
||||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
|
||||||
// Iterate each output pixel
|
|
||||||
for (size_t outX = 0; outX < output_w; outX++) {
|
|
||||||
for (size_t outY = 0; outY < output_h; outY++) {
|
|
||||||
auto outputTile = outputTiles[outTile][outX][outY];
|
|
||||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
|
||||||
if (!outputTileProducer) {
|
|
||||||
return rewriter.notifyMatchFailure(poolOp,
|
|
||||||
"Output tile for Pooling is not produced by a "
|
|
||||||
"WeightedComputeOp.");
|
|
||||||
}
|
|
||||||
|
|
||||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
|
||||||
|
|
||||||
rewriter.replaceOp(poolOp, outputImage);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
|
|
||||||
ctx);
|
|
||||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVAddOp : Pat<
|
def spatToPimVVAddOp : Pat<
|
||||||
(SpatVAddOp:$srcOpRes $a, $b),
|
(SpatVAddOp:$srcOpRes $a, $b),
|
||||||
(PimVAddOp $a, $b,
|
(PimVVAddOp $a, $b,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVVMulOp : Pat<
|
||||||
|
(SpatVMulOp:$srcOpRes $a, $b),
|
||||||
|
(PimVVMulOp $a, $b,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVVMaxOp : Pat<
|
||||||
|
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||||
|
(PimVVMaxOp $a, $b,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Element-wise addition: c = a + b
|
Element-wise addition: c = a + b
|
||||||
}];
|
}];
|
||||||
@@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||||
|
let description = [{
|
||||||
|
Element-wise subtraction: c = a - b
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor: $a,
|
||||||
|
PimTensor: $b,
|
||||||
|
PimTensor: $outBuf
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor: $outRes
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutBufMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||||
|
let description = [{
|
||||||
|
Element-wise multiplication: c = a * b
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor: $a,
|
||||||
|
PimTensor: $b,
|
||||||
|
PimTensor: $outBuf
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor: $outRes
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutBufMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Element-wise max: c = max(a, b)
|
Element-wise max: c = max(a, b)
|
||||||
}];
|
}];
|
||||||
@@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
PimTensor: $outRes
|
PimTensor: $outRes
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutBufMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||||
|
let description = [{
|
||||||
|
Dot product: c = dot(a, b)
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor: $a,
|
||||||
|
PimTensor: $b,
|
||||||
|
PimTensor: $outBuf
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor: $outRes
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||||
@@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
|
Average all elements into a single one
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
PimTensor: $dividend,
|
PimTensor: $a,
|
||||||
PimTensor: $divisor,
|
|
||||||
PimTensor: $outBuf
|
PimTensor: $outBuf
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Element-wise exp: c = exp(a)
|
Element-wise tanh activation
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor: $a,
|
||||||
|
PimTensor: $outBuf
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor: $outRes
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||||
|
let description = [{
|
||||||
|
Element-wise sigmoid activation
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
@@ -388,4 +480,4 @@ def PimHaltOp: PimOp<"halt", [Terminator]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // PIM_DIALECT_H
|
#endif // PIM_DIALECT_H
|
||||||
|
|||||||
@@ -30,12 +30,13 @@ void PimDialect::initialize() {
|
|||||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
||||||
}
|
}
|
||||||
|
|
||||||
POPULATE_DEPENDENCIES(PimVMaxOp)
|
POPULATE_DEPENDENCIES(PimVVDMulOp)
|
||||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
||||||
POPULATE_DEPENDENCIES(PimSumOp)
|
POPULATE_DEPENDENCIES(PimSumOp)
|
||||||
POPULATE_DEPENDENCIES(PimVSDivOp)
|
POPULATE_DEPENDENCIES(PimVAvgOp)
|
||||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
POPULATE_DEPENDENCIES(PimVReluOp)
|
||||||
POPULATE_DEPENDENCIES(PimVExpOp)
|
POPULATE_DEPENDENCIES(PimVTanhOp)
|
||||||
|
POPULATE_DEPENDENCIES(PimVSigmOp)
|
||||||
|
|
||||||
} // namespace pim
|
} // namespace pim
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include "OpBufferizationInterfaces.hpp"
|
#include "OpBufferizationInterfaces.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -12,6 +13,26 @@ using namespace bufferization;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
|
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
|
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||||
|
return memrefValue;
|
||||||
|
|
||||||
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
|
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||||
|
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||||
|
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
|
return PimMemCopyOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
contiguousType,
|
||||||
|
contiguousBuffer,
|
||||||
|
memrefValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
|
.getDstOut();
|
||||||
|
}
|
||||||
|
|
||||||
struct MemCopyHostToDevOpInterface
|
struct MemCopyHostToDevOpInterface
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
|
template <typename OpTy>
|
||||||
|
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
}
|
}
|
||||||
@@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
|
|||||||
RewriterBase& rewriter,
|
RewriterBase& rewriter,
|
||||||
const BufferizationOptions& options,
|
const BufferizationOptions& options,
|
||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto vaddOp = cast<PimVAddOp>(op);
|
auto binaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
|
||||||
if (failed(aOpt))
|
if (failed(aOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
|
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
|
||||||
if (failed(bOpt))
|
if (failed(bOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
|
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
|
||||||
if (failed(outBufOpt))
|
if (failed(outBufOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
|
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
|
||||||
|
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
|||||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
|
||||||
|
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
|
||||||
|
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
|
||||||
|
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||||
@@ -36,6 +37,25 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
|||||||
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
|
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||||
|
return memrefValue;
|
||||||
|
|
||||||
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
|
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
|
||||||
|
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
|
return pim::PimMemCopyOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
contiguousBuffer.getType(),
|
||||||
|
contiguousBuffer,
|
||||||
|
memrefValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
|
.getDstOut();
|
||||||
|
}
|
||||||
|
|
||||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||||
|
|
||||||
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
|
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
|
||||||
@@ -167,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
|||||||
auto memref = getBuffer(rewriter, operand, options, state);
|
auto memref = getBuffer(rewriter, operand, options, state);
|
||||||
if (failed(memref))
|
if (failed(memref))
|
||||||
return failure();
|
return failure();
|
||||||
memrefOperands.push_back(*memref);
|
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Support addiction with more than 2 operands
|
// TODO: Support addiction with more than 2 operands
|
||||||
@@ -460,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct VAddOpInterfaceFromTemplate
|
struct VAddOpInterfaceFromTemplate
|
||||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
|
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
||||||
|
|
||||||
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||||
|
|
||||||
@@ -468,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
|
|||||||
|
|
||||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||||
|
|
||||||
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
|
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||||
|
|
||||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
|
|
||||||
|
|
||||||
// Create a new bufferizable op interface for the apply filters operation.
|
// Create a new bufferizable op interface for the apply filters operation.
|
||||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
||||||
@@ -557,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
|||||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
||||||
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
|
|
||||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||||
@@ -569,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
|||||||
|
|
||||||
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||||
|
|
||||||
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
|
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
||||||
|
|
||||||
|
struct ONNXSigmoidInterface
|
||||||
|
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||||
|
|
||||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
||||||
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
||||||
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
|
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
|
||||||
|
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ add_pim_library(OMPimPasses
|
|||||||
PimConstantFolding/Patterns/Constant.cpp
|
PimConstantFolding/Patterns/Constant.cpp
|
||||||
PimConstantFolding/PimConstantFoldingPass.cpp
|
PimConstantFolding/PimConstantFoldingPass.cpp
|
||||||
PimConstantFolding/Patterns/Subview.cpp
|
PimConstantFolding/Patterns/Subview.cpp
|
||||||
PimHostVerificationPass.cpp
|
PimMaterializeConstantsPass.cpp
|
||||||
|
PimVerificationPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -120,20 +120,8 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
rewriter.setInsertionPoint(coreOp);
|
rewriter.setInsertionPoint(coreOp);
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
|
|
||||||
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
|
|
||||||
if (elementByteWidth == 0)
|
|
||||||
return failure();
|
|
||||||
size_t totalBytes = initType.getNumElements() * elementByteWidth;
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(mapOp);
|
rewriter.setInsertionPoint(mapOp);
|
||||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
|
||||||
mapOp.getLoc(),
|
|
||||||
initType,
|
|
||||||
mapOp.getInit(),
|
|
||||||
getGlobalOp.getResult(),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
|
||||||
rewriter.eraseOp(mapOp);
|
rewriter.eraseOp(mapOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
135
src/PIM/Pass/PimMaterializeConstantsPass.cpp
Normal file
135
src/PIM/Pass/PimMaterializeConstantsPass.cpp
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/MathExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
|
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
|
return operandIndex == 1;
|
||||||
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
|
return operandIndex == 0;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t getValueSizeInBytes(Value value) {
|
||||||
|
auto type = dyn_cast<ShapedType>(value.getType());
|
||||||
|
if (!type || !type.hasStaticShape())
|
||||||
|
return -1;
|
||||||
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PimMaterializeConstantsPass
|
||||||
|
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||||
|
StringRef getDescription() const override {
|
||||||
|
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
ModuleOp moduleOp = getOperation();
|
||||||
|
OpBuilder rewriter(moduleOp.getContext());
|
||||||
|
bool hasFailure = false;
|
||||||
|
|
||||||
|
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||||
|
if (funcOp.isExternal())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
|
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
|
||||||
|
if (isa<pim::PimHaltOp>(op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (OpOperand& operand : op.getOpOperands()) {
|
||||||
|
Value originalValue = operand.get();
|
||||||
|
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||||
|
if (failed(resolvedAddress))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
|
if (!getGlobalOp)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||||
|
if (!originalType || !originalType.hasStaticShape()) {
|
||||||
|
op.emitOpError("host constant materialization requires a static memref operand");
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||||
|
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||||
|
auto cachedValue = cachedByType.find(originalType);
|
||||||
|
if (cachedValue != cachedByType.end()) {
|
||||||
|
operand.set(cachedValue->second);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
||||||
|
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||||
|
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(&op);
|
||||||
|
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
|
||||||
|
Value deviceDst = localAlloc;
|
||||||
|
if (contiguousType != originalType)
|
||||||
|
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
|
||||||
|
|
||||||
|
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
originalType,
|
||||||
|
deviceDst,
|
||||||
|
getGlobalOp.getResult(),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(
|
||||||
|
static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(
|
||||||
|
static_cast<int32_t>(totalBytes)));
|
||||||
|
|
||||||
|
cachedByType[originalType] = hostToDevCopy.getResult();
|
||||||
|
operand.set(hostToDevCopy.getResult());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasFailure) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dumpModule(moduleOp, "pim3_materialized");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
|
||||||
|
return std::make_unique<PimMaterializeConstantsPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||||
|
|
||||||
|
|||||||
@@ -35,16 +35,24 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
|
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
|
return operandIndex == 1;
|
||||||
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
|
return operandIndex == 0;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
StringRef getArgument() const override { return "verify-pim-host-pass"; }
|
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Verify that no runtime host-side code remains in bufferized PIM IR";
|
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||||
}
|
}
|
||||||
|
|
||||||
PimHostVerificationPass() {}
|
PimVerificationPass() {}
|
||||||
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
|
PimVerificationPass(const PimVerificationPass& pass) {}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
@@ -132,11 +140,27 @@ private:
|
|||||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||||
if (!isa<BaseMemRefType>(operand.getType()))
|
if (!isa<BaseMemRefType>(operand.getType()))
|
||||||
continue;
|
continue;
|
||||||
if (succeeded(resolveContiguousAddress(operand)))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
auto resolvedAddress = resolveContiguousAddress(operand);
|
||||||
hasFailure = true;
|
if (failed(resolvedAddress)) {
|
||||||
|
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||||
|
if (!isCodegenAddressableValue(operand)) {
|
||||||
|
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||||
|
op.emitOpError() << "operand #" << operandIndex
|
||||||
|
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
@@ -165,6 +189,6 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
|
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createBufferizePimPass);
|
registerPass(createBufferizePimPass);
|
||||||
registerPass(createPimConstantFoldingPass);
|
registerPass(createPimConstantFoldingPass);
|
||||||
registerPass(createPimHostVerificationPass);
|
registerPass(createPimMaterializeConstantsPass);
|
||||||
|
registerPass(createPimVerificationPass);
|
||||||
registerPass(createEmitPimJsonPass);
|
registerPass(createEmitPimJsonPass);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,18 @@ python3 validation/operations/gen_tests.py
|
|||||||
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
|
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
|
||||||
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
||||||
|
|
||||||
|
## Pool
|
||||||
|
|
||||||
|
| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes |
|
||||||
|
|------|-----------|-------|--------|--------|--------|---------|-------|
|
||||||
|
| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling |
|
||||||
|
| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool |
|
||||||
|
| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path |
|
||||||
|
| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling |
|
||||||
|
| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` |
|
||||||
|
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
|
||||||
|
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
|
||||||
|
|
||||||
## Gemm
|
## Gemm
|
||||||
|
|
||||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Generate ONNX test models for validating GEMM and Conv implementations."""
|
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnx
|
import onnx
|
||||||
@@ -248,6 +248,85 @@ def conv_large_spatial():
|
|||||||
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
|
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pooling tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def maxpool_basic():
|
||||||
|
"""MaxPool 2x2 with stride 1."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||||
|
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "maxpool_basic", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/max_basic", "maxpool_basic.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def maxpool_stride2_multichannel():
|
||||||
|
"""MaxPool 2x2 with stride 2 on multiple channels."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3])
|
||||||
|
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def maxpool_same_upper():
|
||||||
|
"""MaxPool 3x3 with SAME_UPPER padding."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||||
|
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER")
|
||||||
|
graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def avgpool_basic():
|
||||||
|
"""AveragePool 2x2 with stride 1."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3])
|
||||||
|
node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "avgpool_basic", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/avg_basic", "avgpool_basic.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def avgpool_explicit_padding():
|
||||||
|
"""AveragePool 3x3 with explicit padding, excluding pad from the divisor."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||||
|
node = helper.make_node("AveragePool", ["X"], ["Y"],
|
||||||
|
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0)
|
||||||
|
graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def avgpool_include_pad():
|
||||||
|
"""AveragePool 3x3 with explicit padding, including pad in the divisor."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||||
|
node = helper.make_node("AveragePool", ["X"], ["Y"],
|
||||||
|
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1)
|
||||||
|
graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def maxpool_after_conv():
|
||||||
|
"""Conv followed by MaxPool to validate pooling on lowered conv results."""
|
||||||
|
rng = np.random.default_rng(59)
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2])
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||||
|
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Main
|
# Main
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -273,4 +352,13 @@ if __name__ == "__main__":
|
|||||||
conv_batch_2()
|
conv_batch_2()
|
||||||
conv_large_spatial()
|
conv_large_spatial()
|
||||||
|
|
||||||
|
print("\nGenerating Pooling tests:")
|
||||||
|
maxpool_basic()
|
||||||
|
maxpool_stride2_multichannel()
|
||||||
|
maxpool_same_upper()
|
||||||
|
avgpool_basic()
|
||||||
|
avgpool_explicit_padding()
|
||||||
|
avgpool_include_pad()
|
||||||
|
maxpool_after_conv()
|
||||||
|
|
||||||
print("\nDone.")
|
print("\nDone.")
|
||||||
|
|||||||
BIN
validation/operations/pool/avg_basic/avgpool_basic.onnx
Normal file
BIN
validation/operations/pool/avg_basic/avgpool_basic.onnx
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
validation/operations/pool/max_basic/maxpool_basic.onnx
Normal file
BIN
validation/operations/pool/max_basic/maxpool_basic.onnx
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user