diff --git a/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs b/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs index fc9d5be..2eff04b 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/instruction_set/isa.rs @@ -52,6 +52,7 @@ static NAMES: LazyLock> = LazyLock::new(|| { add_name_simd!(hash, vrelu); add_name_simd!(hash, vtanh); add_name_simd!(hash, vsigm); + add_name_simd!(hash, vsoftmax); add_name!(hash, vmv); add_name!(hash, vrsu); add_name!(hash, vrsl); @@ -177,6 +178,7 @@ static SIMD: LazyLock>> add_simd_to_map!(storage, vrelu); add_simd_to_map!(storage, vtanh); add_simd_to_map!(storage, vsigm); + add_simd_to_map!(storage, vsoftmax); add_simd_to_map!(storage, mvmul); storage }); @@ -626,6 +628,46 @@ where Ok(InstructionStatus::Completed) } +pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result { + panic!("You are calling a placeholder, the real call is the generic version"); +} + +pub(super) fn vsoftmax_impl(cores: &mut CPU, data: InstructionData) -> Result +where + [F]: UpcastSlice, + T: UpcastDestTraits + MemoryStorable, + F: UpcastDestTraits + MemoryStorable + From, +{ + TRACER.lock().unwrap().pre_vsoftmax::(cores, data); + let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = + data.get_core_rd_r1_r2_immlen_offset(); + let core = cores.core(core_indx); + let r1_val = core.register(r1); + let rd_val = core.register(rd); + let r1_val = add_offset_r1(r1_val, offset_select, offset_value); + let rd_val = add_offset_rd(rd_val, offset_select, offset_value); + let loads = core.reserve_load(r1_val, imm_len)?.execute_load::()?; + let load1 = loads[0]; + ensure!(!load1.is_empty(), "vsoftmax does not support empty vectors"); + let max_val = load1 + .iter() + .copied() + .reduce(|a, b| if a > b { a } else { b }) + .unwrap(); + let exp_values: Vec = load1.iter().map(|&a| (a - max_val).exp()).collect(); + let sum = exp_values + .iter() + .copied() + .reduce(|a, b| a + b) + .unwrap(); + ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive"); + let res: Vec = exp_values.iter().map(|&a| a / sum).collect(); + let res_up: Cow<[T]> = res.as_slice().up(); + core.execute_store(rd_val, res_up.as_ref()); + TRACER.lock().unwrap().post_vsoftmax::(cores, data); + Ok(InstructionStatus::Completed) +} + pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result { todo!() } diff --git a/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs b/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs index f2885b8..e20d3c9 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/json_to_instruction/json_isa.rs @@ -40,6 +40,7 @@ static SIMD: LazyLock> = LazyLock::new(|| { add_to_json_map!(storage, vrelu); add_to_json_map!(storage, vtanh); add_to_json_map!(storage, vsigm); + add_to_json_map!(storage, vsoftmax); add_to_json_map!(storage, vmv); add_to_json_map!(storage, vrsu); add_to_json_map!(storage, vrsl); @@ -417,6 +418,27 @@ fn json_to_vsigm( Ok(()) } +fn json_to_vsoftmax( + inst_builder: &mut InstructionsBuilder, + inst_data_builder: &mut InstructionDataBuilder, + json: &Value, +) -> Result<()> { + let json = json.as_object().expect("Not an object"); + assert_eq!("vsoftmax", json_str!(json, "op")); + let rd = json_i64!(json, "rd") as i32; + let rs1 = json_i64!(json, "rs1") as i32; + let len = json_i64!(json, "len") as i32; + let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap()); + inst_data_builder + .set_rd(rd) + .set_r1(rs1) + .set_imm_len(len) + .set_offset_select(offset_select) + .set_offset_value(offset_value); + inst_builder.make_inst(vsoftmax, inst_data_builder.build()); + Ok(()) +} + fn json_to_vmv( inst_builder: &mut InstructionsBuilder, inst_data_builder: &mut InstructionDataBuilder, diff --git a/backend-simulators/pim/pim-simulator/src/lib/memory_manager/type_traits.rs b/backend-simulators/pim/pim-simulator/src/lib/memory_manager/type_traits.rs index 151209e..6778dfa 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/memory_manager/type_traits.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/memory_manager/type_traits.rs @@ -67,6 +67,22 @@ impl HasSigm for f64 { } } +pub trait HasExp { + fn exp(self) -> Self; +} + +impl HasExp for f32 { + fn exp(self) -> Self { + self.exp() + } +} + +impl HasExp for f64 { + fn exp(self) -> Self { + self.exp() + } +} + pub trait TryToUsize: TryInto @@ -112,6 +128,7 @@ pub trait UpcastDestTraits: + PartialOrd + HasTanh + HasSigm + + HasExp + FromUsize { } diff --git a/backend-simulators/pim/pim-simulator/src/lib/tracing/disable.rs b/backend-simulators/pim/pim-simulator/src/lib/tracing/disable.rs index 47729ad..f31bf38 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/tracing/disable.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/tracing/disable.rs @@ -248,6 +248,22 @@ impl Trace { { } + pub fn pre_vsoftmax(&mut self, cores: &mut CPU, data: InstructionData) + where + [F]: UpcastSlice, + T: UpcastDestTraits + MemoryStorable, + F: UpcastDestTraits + MemoryStorable + From, + { + } + + pub fn post_vsoftmax(&mut self, cores: &mut CPU, data: InstructionData) + where + [F]: UpcastSlice, + T: UpcastDestTraits + MemoryStorable, + F: UpcastDestTraits + MemoryStorable + From, + { + } + ///////////////////////////////////////////////////////////////// /////Communication/synchronization Instructions///////////////// ///////////////////////////////////////////////////////////////// diff --git a/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs b/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs index 55026ca..6d0177b 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs @@ -956,6 +956,35 @@ impl Trace { // Ok(InstructionStatus::Completed) } + pub fn pre_vsoftmax(&mut self, cores: &mut CPU, data: InstructionData) + where + [F]: UpcastSlice, + T: UpcastDestTraits + MemoryStorable, + F: UpcastDestTraits + MemoryStorable + From, + { + let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = + data.get_core_rd_r1_r2_immlen_offset(); + let file: &mut File = self + .out_files + .get_mut(core_indx as usize) + .expect("File at index not found"); + writeln!(file, "\t\tVSOFTMAX\t\t"); + } + + pub fn post_vsoftmax(&mut self, cores: &mut CPU, data: InstructionData) + where + [F]: UpcastSlice, + T: UpcastDestTraits + MemoryStorable, + F: UpcastDestTraits + MemoryStorable + From, + { + let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = + data.get_core_rd_r1_r2_immlen_offset(); + let file: &mut File = self + .out_files + .get_mut(core_indx as usize) + .expect("File at index not found"); + } + ///////////////////////////////////////////////////////////////// /////Communication/synchronization Instructions///////////////// ///////////////////////////////////////////////////////////////// diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index d799c46..c48de47 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -386,6 +386,20 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const { emitInstruction(std::move(json)); } +void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const { + auto outputBufferAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer()); + auto inputAddr = memory.getValueAddress(vsoftmaxOp.getInput()); + setupRdRs1(outputBufferAddr, 0, inputAddr, 0); + + json::Object json; + json["op"] = "vsoftmax"; + json["rd"] = 0; + json["rs1"] = 1; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vsoftmaxOp.getInput()); + emitInstruction(std::move(json)); +} + void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { auto srcAddr = memory.getValueAddress(transposeOp.getInput()); auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer()); @@ -537,6 +551,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVTanhOp(vtanhOp); else if (auto vsigmOp = dyn_cast(op)) coreCodeGen.codeGenVSigmOp(vsigmOp); + else if (auto vsoftmaxOp = dyn_cast(op)) + coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp); else { op.emitError("Unsupported codegen for this operation"); op.dump(); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index ab0f1a0..81f340d 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -99,6 +99,7 @@ public: void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const; + void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; }; diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 1e3b3ae..d26673f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -11,8 +11,12 @@ add_pim_library(OMONNXToSpatial Patterns/NN/Pool.cpp Patterns/NN/Relu.cpp Patterns/NN/Sigmoid.cpp + Patterns/NN/Softmax.cpp Patterns/Tensor/Concat.cpp + Patterns/Tensor/Gather.cpp + Patterns/Tensor/Resize.cpp Patterns/Tensor/Reshape.cpp + Patterns/Tensor/Split.cpp ONNXToSpatialPass.cpp Common.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 1649c34..d140ee2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -89,9 +89,12 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); RewritePatternSet patterns(ctx); patterns.add(ctx); @@ -103,8 +106,12 @@ void ONNXToSpatialPass::runOnOperation() { populateReduceMeanPatterns(patterns, ctx); populateReluPatterns(patterns, ctx); populateSigmoidPatterns(patterns, ctx); + populateSoftmaxPatterns(patterns, ctx); populateConcatPatterns(patterns, ctx); + populateGatherPatterns(patterns, ctx); + populateResizePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx); + populateSplitPatterns(patterns, ctx); if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure(); @@ -168,7 +175,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources); llvm::SmallVector sourceTypes; llvm::SmallVector sourceLoc; - for (auto source : sources){ + for (auto source : sources) { sourceTypes.push_back(source.getType()); sourceLoc.push_back(loc); } @@ -176,7 +183,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()}); rewriter.setInsertionPointToEnd(BB); IRMapping mapper; - for(auto [source,bbArg] : llvm::zip(sources, BB->getArguments())) + for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) mapper.map(source, bbArg); auto newConcat = rewriter.clone(*inst, mapper); spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0)); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index 38232ba..7c44286 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -21,8 +21,16 @@ void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp index 21221e5..35c8e7d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp @@ -92,10 +92,8 @@ static FailureOr materializeBroadcastedConstantTensor(Value value, return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult(); } -static FailureOr prepareElementwiseOperand(Value value, - RankedTensorType resultType, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr +prepareElementwiseOperand(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) { auto valueType = dyn_cast(value.getType()); if (!valueType || !valueType.hasStaticShape()) return failure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 0e87f7a..fe0d1cb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -280,8 +280,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) weights.push_back(bTiles[outSliceId][coreId][aSliceId]); - auto computeOp = - createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { + auto computeOp = createSpatCompute( + rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { SmallVector vmmOutputs; vmmOutputs.reserve(aHSlicesArgs.size()); for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index c78e078..a90f107 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -71,10 +71,8 @@ static SmallVector buildCollapseReassociation(ArrayRef(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) { auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x); @@ -141,7 +139,8 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern { Location loc = reduceMeanOp.getLoc(); RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType()); - Value reducedKeepdims = buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); + Value reducedKeepdims = + buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); if (reduceMeanOp.getKeepdims() != 0) { rewriter.replaceOp(reduceMeanOp, reducedKeepdims); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp new file mode 100644 index 0000000..e141615 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -0,0 +1,111 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +static SmallVector permuteShape(ArrayRef shape, ArrayRef permutation) { + SmallVector permutedShape; + permutedShape.reserve(permutation.size()); + for (int64_t axis : permutation) + permutedShape.push_back(shape[axis]); + return permutedShape; +} + +static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { + auto inputType = cast(input.getType()); + constexpr size_t numInputs = 1; + auto computeOp = + createSpatCompute(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { + auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); + spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); + }); + return computeOp.getResult(0); +} + +static Value +buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { + auto inputType = cast(input.getType()); + if (axis == inputType.getRank()) + return createSoftmaxCompute(input, rewriter, loc); + + if (axis == softmaxAxis) + return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc); + + SmallVector slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc); + SmallVector rebuiltSlices; + rebuiltSlices.reserve(slices.size()); + for (Value slice : slices) + rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); + + return rebuiltSlices.size() == 1 ? rebuiltSlices.front() + : tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult(); +} + +struct SoftmaxToSpatialCompute : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp, + ONNXSoftmaxOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getInput().getType()); + if (!inputType || !inputType.hasStaticShape()) + return failure(); + + int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank()); + if (axis < 0 || axis >= inputType.getRank()) + return failure(); + + Value input = adaptor.getInput(); + Value result; + if (axis == inputType.getRank() - 1) { + result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); + } + else { + SmallVector permutation; + permutation.reserve(inputType.getRank()); + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) + if (dim != axis) + permutation.push_back(dim); + permutation.push_back(axis); + + SmallVector inversePermutation(inputType.getRank()); + for (auto [newIndex, oldIndex] : llvm::enumerate(permutation)) + inversePermutation[oldIndex] = static_cast(newIndex); + + auto transposedType = RankedTensorType::get( + permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); + auto preTransposeCompute = + createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { + Value transposed = ONNXTransposeOp::create( + rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation)); + spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); + }); + Value transposedInput = preTransposeCompute.getResult(0); + Value transposedResult = buildSoftmax( + transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); + result = ONNXTransposeOp::create( + rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation)); + } + + rewriter.replaceOp(softmaxOp, result); + return success(); + } +}; + +} // namespace + +void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.add(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp index 84271a8..a32c551 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp @@ -23,8 +23,6 @@ struct Concat : public OpConversionPattern { } }; -void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); -} +void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp new file mode 100644 index 0000000..6605dc1 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp @@ -0,0 +1,157 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; } + +static Value +extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) { + auto inputType = cast(input.getType()); + SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); + sizes.reserve(inputType.getRank()); + for (int64_t dim : inputType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dim)); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(1); + return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +} + +static Value concatGatherSlices(Value data, + int64_t axis, + ArrayRef indices, + int64_t axisDim, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector slices; + slices.reserve(indices.size()); + for (int64_t index : indices) { + int64_t normalizedIndex = normalizeIndex(index, axisDim); + if (normalizedIndex < 0 || normalizedIndex >= axisDim) + return {}; + slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc)); + } + if (slices.empty()) + return {}; + return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); +} + +static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { + auto valueType = cast(value.getType()); + SmallVector resultShape; + SmallVector reassociation; + resultShape.reserve(valueType.getRank() + 1); + reassociation.reserve(valueType.getRank()); + + int64_t resultDim = 0; + for (int64_t dim = 0; dim < valueType.getRank(); ++dim) { + if (dim == axis) { + resultShape.push_back(1); + resultShape.push_back(valueType.getShape()[dim]); + reassociation.push_back({static_cast(resultDim), static_cast(resultDim + 1)}); + resultDim += 2; + continue; + } + resultShape.push_back(valueType.getShape()[dim]); + reassociation.push_back({static_cast(resultDim)}); + resultDim++; + } + + auto resultType = RankedTensorType::get(resultShape, valueType.getElementType(), valueType.getEncoding()); + return tensor::ExpandShapeOp::create(rewriter, loc, resultType, value, reassociation); +} + +struct Gather : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXGatherOp gatherOp, + ONNXGatherOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto dataType = dyn_cast(adaptor.getData().getType()); + auto indicesType = dyn_cast(adaptor.getIndices().getType()); + if (!dataType || !indicesType || !dataType.hasStaticShape() || !indicesType.hasStaticShape()) + return failure(); + + auto indicesConst = adaptor.getIndices().getDefiningOp(); + if (!indicesConst) + return failure(); + auto indicesAttr = dyn_cast(indicesConst.getValue()); + if (!indicesAttr) + return failure(); + + int64_t rank = dataType.getRank(); + int64_t axis = normalizeAxis(gatherOp.getAxis(), rank); + if (axis < 0 || axis >= rank) + return failure(); + + int64_t axisDim = dataType.getShape()[axis]; + if (axisDim <= 0) + return failure(); + + SmallVector flatIndices(indicesAttr.getValues().begin(), indicesAttr.getValues().end()); + Location loc = gatherOp.getLoc(); + + auto computeOp = + createSpatCompute<1>(rewriter, + loc, + TypeRange {gatherOp.getResult().getType()}, + {}, + adaptor.getData(), + [&](Value data) -> LogicalResult { + Value result; + if (indicesType.getRank() == 1) { + result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); + } + else if (indicesType.getRank() == 2) { + int64_t rowCount = indicesType.getShape()[0]; + int64_t rowWidth = indicesType.getShape()[1]; + SmallVector rows; + rows.reserve(rowCount); + for (int64_t row = 0; row < rowCount; ++row) { + ArrayRef rowIndices(flatIndices.data() + row * rowWidth, rowWidth); + Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc); + if (!gatheredRow) + return failure(); + rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); + } + result = rows.size() == 1 + ? rows.front() + : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult(); + } + else { + return failure(); + } + + if (!result) + return failure(); + spatial::SpatYieldOp::create(rewriter, loc, result); + return success(); + }); + if (failed(computeOp)) + return failure(); + rewriter.replaceOp(gatherOp, computeOp->getResults()); + return success(); + } +}; + +} // namespace + +void populateGatherPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp new file mode 100644 index 0000000..d5e340e --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -0,0 +1,90 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/STLExtras.h" + +#include + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static Value +extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) { + auto inputType = cast(input.getType()); + SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); + sizes.reserve(inputType.getRank()); + for (int64_t dim : inputType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dim)); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(1); + return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +} + +static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) { + return std::min((outputIndex * inputDim) / outputDim, inputDim - 1); +} + +static Value buildNearestResize(Value input, + ArrayRef inputShape, + ArrayRef outputShape, + int64_t axis, + ConversionPatternRewriter& rewriter, + Location loc) { + if (axis == static_cast(outputShape.size())) + return input; + + SmallVector slices; + slices.reserve(outputShape[axis]); + for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) { + int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]); + Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc); + slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); + } + + return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); +} + +struct Resize : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXResizeOp resizeOp, + ONNXResizeOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getX().getType()); + auto resultType = dyn_cast(resizeOp.getY().getType()); + if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + + if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" + || resizeOp.getNearestMode() != "floor") + return failure(); + + if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) + || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) + return failure(); + + auto computeOp = + createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { + Value result = + buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc()); + spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); + }); + rewriter.replaceOp(resizeOp, computeOp.getResults()); + return success(); + } +}; + +} // namespace + +void populateResizePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp new file mode 100644 index 0000000..8ed7fcd --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -0,0 +1,70 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +static Value extractSliceAt( + Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) { + auto inputType = cast(input.getType()); + SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); + sizes.reserve(inputType.getRank()); + for (int64_t dim : inputType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dim)); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(size); + return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +} + +struct Split : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ONNXSplitOp splitOp, ONNXSplitOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getInput().getType()); + if (!inputType || !inputType.hasStaticShape()) + return failure(); + + int64_t rank = inputType.getRank(); + int64_t axis = normalizeAxis(splitOp.getAxis(), rank); + if (axis < 0 || axis >= rank) + return failure(); + + SmallVector outputs; + outputs.reserve(splitOp.getNumResults()); + + int64_t offset = 0; + for (Value result : splitOp.getResults()) { + auto resultType = dyn_cast(result.getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + int64_t sliceSize = resultType.getShape()[axis]; + auto computeOp = + createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) { + Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc()); + spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output); + }); + outputs.push_back(computeOp.getResult(0)); + offset += sliceSize; + } + + rewriter.replaceOp(splitOp, outputs); + return success(); + } +}; + +} // namespace + +void populateSplitPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index ce94a90..a0fbce5 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -63,4 +63,10 @@ def spatToPimVSigm : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; +def spatToPimVSoftmax : Pat< + (SpatSoftmaxOp:$srcOpRes $input), + (PimVSoftmaxOp $input, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + #endif // SPATIAL_TO_PIM diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index b527f6d..07148d3 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -618,17 +618,22 @@ void SpatialToPimPass::markOpToRemove(Operation* op) { } void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { - for (auto it : llvm::enumerate(returnOp.getOperands())) { - Operation* returnOperand = it.value().getDefiningOp(); - + SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); + for (auto it : llvm::enumerate(originalOperands)) { size_t orderWithinReturn = it.index(); + Operation* returnOperand = it.value().getDefiningOp(); rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); Operation* opToErase = returnOperand; while (opToErase) { - bool isExclusivelyOwnedByReturnChain = opToErase->use_empty() || opToErase->hasOneUse(); + bool isExclusivelyOwnedByReturnChain = opToErase->use_empty(); + if (!isExclusivelyOwnedByReturnChain && opToErase->hasOneUse()) { + Operation* onlyUser = *opToErase->getUsers().begin(); + isExclusivelyOwnedByReturnChain = + isa(onlyUser) || isChannelUseChainOp(onlyUser); + } if (!isExclusivelyOwnedByReturnChain) break; diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 300ddd1..2a6d8e9 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -455,4 +455,27 @@ def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> { }]; } +def PimVSoftmaxOp : PimOp<"vsoftmax", [DestinationStyleOpInterface]> { + let summary = "Softmax over the full input vector"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + #endif // PIM_DIALECT_H diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 5d29208..eae40c2 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -273,6 +273,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimVReluOp::attachInterface>(*ctx); PimVTanhOp::attachInterface>(*ctx); PimVSigmOp::attachInterface>(*ctx); + PimVSoftmaxOp::attachInterface>(*ctx); }); } diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp b/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp index 15325f1..0288b5c 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp +++ b/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp @@ -485,7 +485,7 @@ DCPAnalysisResult GraphDCP::getResult() { size_t i = 0; for (auto node : nodes) { ret.computeToCPUMap[node->getSpatWeightedCompute()] = cpu; - if (i++ == nodes.size() - 1){ + if (i++ == nodes.size() - 1) { ret.isLastComputeOfACpu.insert(node->getSpatWeightedCompute()); ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute(); } diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp b/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp index 724475b..d8f0114 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp +++ b/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp @@ -43,7 +43,5 @@ bool TaskDCP::hasDescendent(TaskDCP* child) { return false; } -//TODO fare qualcosa di sensato -int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { - return orig_weight; -} +// TODO fare qualcosa di sensato +int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { return orig_weight; } diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp index 748df3f..2368525 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp +++ b/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp @@ -75,11 +75,11 @@ public: alst = val; } bool hasDescendent(TaskDCP* child); - int64_t Id() const { return (int64_t)spatWeightedCompute.getAsOpaquePointer(); } + int64_t Id() const { return (int64_t) spatWeightedCompute.getAsOpaquePointer(); } bool isCP() const { return alst == aest; } bool isScheduled() const { return scheduledCPU.has_value(); } - onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute(){return spatWeightedCompute;} + onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() { return spatWeightedCompute; } friend std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); friend void removeEdge(TaskDCP* parent, TaskDCP* child); diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp index 59d6bbe..e24e3a8 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp +++ b/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp @@ -71,12 +71,7 @@ public: return true; } - auto begin() { - return storage.begin(); - } - - auto end() { - return storage.end(); - } + auto begin() { return storage.begin(); } + auto end() { return storage.end(); } }; diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp index 3ca2b95..276e769 100644 --- a/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp +++ b/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp @@ -1,7 +1,9 @@ #pragma once #include "mlir/IR/BuiltinTypeInterfaces.h" + #include "llvm/Support/Casting.h" + #include #include #include @@ -50,10 +52,9 @@ inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spat int64_t tot = 0; for (auto& region : spatWeightedCompute.getBody()) { for (auto& inst : region) { - for(auto result : inst.getResults()){ - if(auto element = llvm::dyn_cast(result.getType())) - tot += onnx_mlir::getSizeInBytes(element); - } + for (auto result : inst.getResults()) + if (auto element = llvm::dyn_cast(result.getType())) + tot += onnx_mlir::getSizeInBytes(element); } } return tot; diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 12fdd87..c8f419e 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -272,6 +272,22 @@ def SpatSigmoidOp : SpatOp<"sigmoid", []> { }]; } +def SpatSoftmaxOp : SpatOp<"softmax", []> { + let summary = "Softmax over the full input tensor slice"; + + let arguments = (ins + SpatTensor:$input + ); + + let results = (outs + SpatTensor:$output + ); + + let assemblyFormat = [{ + `(` $input `)` attr-dict `:` type($input) `->` type($output) + }]; +} + def SpatReluOp : SpatOp<"relu", []> { let summary = "Element-wise ReLU activation"; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp index c64b6b7..550a8a5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp @@ -202,9 +202,9 @@ private: rewriter.clone(op, mapper); } - for (auto users : oldWeightedCompute->getUsers()) - if (auto funcRet = dyn_cast(users)) - funcRet.setOperand(0, newWeightedCompute.getResult(0)); + for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses())) + if (isa(use.getOwner())) + use.assign(newWeightedCompute.getResult(0)); oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute}); return {cast(newWeightedCompute), computeValueResults}; diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index d81df66..3fa2413 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -413,7 +413,7 @@ struct ChannelBroadcastReceiveOpInterface outputTensor, rewriter.getI32IntegerAttr(numElements * elementSize), rewriter.getI32IntegerAttr(srcCoreId.value())) - .getOutput(); + .getOutput(); replaceOpWithBufferizedValues(rewriter, op, newValue); diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp index 32e5f43..5be11b0 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp @@ -146,6 +146,37 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override { + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getHostTarget(), + copyOp.getDeviceSource(), + copyOp.getHostTargetOffset(), + copyOp.getDeviceSourceOffset(), + copyOp.getSize(), + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + pim::PimMemCopyDevToHostOp::create(rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getHostTarget()); + return success(); + } +}; + struct FoldConstantCoreSubviewPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -209,8 +240,10 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern( - patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/validation/operations/README.md b/validation/operations/README.md index 4870159..3b1d979 100644 --- a/validation/operations/README.md +++ b/validation/operations/README.md @@ -85,6 +85,36 @@ python3 validation/operations/gen_tests.py | 4D | `sigmoid/4d` | [2,3,4,4] | [2,3,4,4] | Standalone NCHW Sigmoid | | After Gemm | `sigmoid/after_gemm` | [4,64] | [4,32] | Gemm + bias, then Sigmoid | +## Softmax + +| Test | Directory | Input | Output | Axis | Notes | +|--------------|--------------------------|-------------|-------------|------|---------------------------------| +| Basic | `softmax/basic` | [3,5] | [3,5] | 1 | Row-wise softmax over features | +| 3D last axis | `softmax/3d_last_axis` | [2,3,4] | [2,3,4] | 2 | Last-dimension normalization | +| Channel axis | `softmax/channel_axis` | [1,3,2,2] | [1,3,2,2] | 1 | NCHW channel-wise softmax | + +## Resize + +| Test | Directory | Input | Output | Mode | Notes | +|---------------------|-------------------------|-----------|-----------|---------|-----------------------------------------| +| Nearest 2x | `resize/nearest_2x` | [1,1,2,3] | [1,1,4,6] | nearest | NCHW upsampling with scales [1,1,2,2] | +| Non-uniform scales | `resize/non_uniform` | [1,1,2,3] | [1,1,6,6] | nearest | Different height/width scaling factors | +| Explicit sizes | `resize/with_sizes` | [1,1,2,3] | [1,1,3,5] | nearest | Sizes input used instead of scales | + +## Split + +| Test | Directory | Input | Outputs | Axis | Notes | +|-----------------|---------------------------|-------|----------------------|------|-------------------------------------| +| Basic | `split/basic` | [2,6] | [2,2], [2,4] | 1 | Two-way split with explicit sizes | +| Equal three-way | `split/equal_three_way` | [2,6] | [2,2], [2,2], [2,2] | 1 | Optional split input omitted | + +## Gather + +| Test | Directory | Input | Indices | Output | Axis | Notes | +|----------------------|--------------------------------|-------|---------|----------|------|--------------------------------| +| Axis 1 | `gather/axis1` | [3,4] | [2] | [3,2] | 1 | Select two columns | +| Axis 0 matrix indices| `gather/axis0_matrix_indices` | [4,3] | [2,2] | [2,2,3] | 0 | Gather rows with 2D indices | + ## Add | Test | Directory | Input(s) | Output | Notes | diff --git a/validation/operations/gather/axis0_matrix_indices/gather_axis0_matrix_indices.onnx b/validation/operations/gather/axis0_matrix_indices/gather_axis0_matrix_indices.onnx new file mode 100644 index 0000000..4119bad Binary files /dev/null and b/validation/operations/gather/axis0_matrix_indices/gather_axis0_matrix_indices.onnx differ diff --git a/validation/operations/gather/axis1/gather_axis1.onnx b/validation/operations/gather/axis1/gather_axis1.onnx new file mode 100644 index 0000000..c0d6ed4 Binary files /dev/null and b/validation/operations/gather/axis1/gather_axis1.onnx differ diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index c146f8d..7e0309c 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Generate ONNX test models for validating GEMM, Conv, Pooling, Relu, and ReduceMean implementations.""" +"""Generate ONNX test models for validating supported ONNX operations.""" import numpy as np import onnx @@ -473,6 +473,140 @@ def sigmoid_after_gemm(): save_model(model, "sigmoid/after_gemm", "sigmoid_after_gemm.onnx") +# --------------------------------------------------------------------------- +# Softmax tests +# --------------------------------------------------------------------------- + +def softmax_basic(): + """Softmax over the last dimension of a 2D tensor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5]) + node = helper.make_node("Softmax", ["X"], ["Y"], axis=1) + graph = helper.make_graph([node], "softmax_basic", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "softmax/basic", "softmax_basic.onnx") + + +def softmax_3d_last_axis(): + """Softmax over the last axis of a 3D tensor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4]) + node = helper.make_node("Softmax", ["X"], ["Y"], axis=2) + graph = helper.make_graph([node], "softmax_3d_last_axis", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "softmax/3d_last_axis", "softmax_3d_last_axis.onnx") + + +def softmax_channel_axis(): + """Softmax over the channel axis of an NCHW tensor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 2, 2]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2]) + node = helper.make_node("Softmax", ["X"], ["Y"], axis=1) + graph = helper.make_graph([node], "softmax_channel_axis", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "softmax/channel_axis", "softmax_channel_axis.onnx") + + +# --------------------------------------------------------------------------- +# Resize tests +# --------------------------------------------------------------------------- + +def resize_nearest_2x(): + """Resize an NCHW tensor with nearest-neighbor upsampling by a factor of 2.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 6]) + roi = numpy_helper.from_array(np.asarray([], dtype=np.float32), name="roi") + scales = numpy_helper.from_array(np.asarray([1.0, 1.0, 2.0, 2.0], dtype=np.float32), name="scales") + node = helper.make_node( + "Resize", ["X", "roi", "scales"], ["Y"], + mode="nearest", coordinate_transformation_mode="asymmetric", nearest_mode="floor") + graph = helper.make_graph([node], "resize_nearest_2x", [X], [Y], initializer=[roi, scales]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "resize/nearest_2x", "resize_nearest_2x.onnx") + + +def resize_nearest_non_uniform(): + """Resize an NCHW tensor with non-uniform nearest-neighbor scales.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6]) + roi = numpy_helper.from_array(np.asarray([], dtype=np.float32), name="roi") + scales = numpy_helper.from_array(np.asarray([1.0, 1.0, 3.0, 2.0], dtype=np.float32), name="scales") + node = helper.make_node( + "Resize", ["X", "roi", "scales"], ["Y"], + mode="nearest", coordinate_transformation_mode="asymmetric", nearest_mode="floor") + graph = helper.make_graph([node], "resize_nearest_non_uniform", [X], [Y], initializer=[roi, scales]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "resize/non_uniform", "resize_non_uniform.onnx") + + +def resize_with_sizes(): + """Resize an NCHW tensor to explicit output sizes.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 5]) + roi = numpy_helper.from_array(np.asarray([], dtype=np.float32), name="roi") + sizes = make_int64_initializer("sizes", [1, 1, 3, 5]) + node = helper.make_node( + "Resize", ["X", "roi", "", "sizes"], ["Y"], + mode="nearest", coordinate_transformation_mode="asymmetric", nearest_mode="floor") + graph = helper.make_graph([node], "resize_with_sizes", [X], [Y], initializer=[roi, sizes]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "resize/with_sizes", "resize_with_sizes.onnx") + + +# --------------------------------------------------------------------------- +# Split tests +# --------------------------------------------------------------------------- + +def split_basic(): + """Split a 2D tensor into two outputs along the feature axis.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 6]) + Y0 = helper.make_tensor_value_info("Y0", TensorProto.FLOAT, [2, 2]) + Y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 4]) + split = make_int64_initializer("split", [2, 4]) + node = helper.make_node("Split", ["X", "split"], ["Y0", "Y1"], axis=1) + graph = helper.make_graph([node], "split_basic", [X], [Y0, Y1], initializer=[split]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "split/basic", "split_basic.onnx") + + +def split_equal_three_way(): + """Split a 2D tensor evenly into three outputs.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 6]) + Y0 = helper.make_tensor_value_info("Y0", TensorProto.FLOAT, [2, 2]) + Y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 2]) + Y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 2]) + node = helper.make_node("Split", ["X"], ["Y0", "Y1", "Y2"], axis=1) + graph = helper.make_graph([node], "split_equal_three_way", [X], [Y0, Y1, Y2]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "split/equal_three_way", "split_equal_three_way.onnx") + + +# --------------------------------------------------------------------------- +# Gather tests +# --------------------------------------------------------------------------- + +def gather_axis1(): + """Gather selected columns from a 2D tensor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2]) + indices = make_int64_initializer("indices", [0, 2]) + node = helper.make_node("Gather", ["X", "indices"], ["Y"], axis=1) + graph = helper.make_graph([node], "gather_axis1", [X], [Y], initializer=[indices]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gather/axis1", "gather_axis1.onnx") + + +def gather_axis0_matrix_indices(): + """Gather rows using a 2D indices tensor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 3]) + indices = make_int64_initializer("indices", [[0, 2], [3, 1]]) + node = helper.make_node("Gather", ["X", "indices"], ["Y"], axis=0) + graph = helper.make_graph([node], "gather_axis0_matrix_indices", [X], [Y], initializer=[indices]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gather/axis0_matrix_indices", "gather_axis0_matrix_indices.onnx") + + # --------------------------------------------------------------------------- # Add tests # --------------------------------------------------------------------------- @@ -599,55 +733,6 @@ def div_after_gemm(): save_model(model, "div/after_gemm", "div_after_gemm.onnx") -# --------------------------------------------------------------------------- -# ReduceMean tests -# --------------------------------------------------------------------------- - -def reducemean_basic(): - """ReduceMean over the feature dimension, preserving rank.""" - X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 8]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 1]) - node = helper.make_node("ReduceMean", ["X"], ["Y"], axes=[1], keepdims=1) - graph = helper.make_graph([node], "reducemean_basic", [X], [Y]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "reduce_mean/basic", "reduce_mean_basic.onnx") - - -def reducemean_keepdims_0(): - """ReduceMean over the feature dimension, dropping the reduced axis.""" - X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 8]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4]) - node = helper.make_node("ReduceMean", ["X"], ["Y"], axes=[1], keepdims=0) - graph = helper.make_graph([node], "reducemean_keepdims_0", [X], [Y]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "reduce_mean/keepdims_0", "reduce_mean_keepdims_0.onnx") - - -def reducemean_4d_spatial(): - """ReduceMean over H and W on an NCHW tensor.""" - X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 1, 1]) - node = helper.make_node("ReduceMean", ["X"], ["Y"], axes=[2, 3], keepdims=1) - graph = helper.make_graph([node], "reducemean_4d_spatial", [X], [Y]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "reduce_mean/4d_spatial", "reduce_mean_4d_spatial.onnx") - - -def reducemean_after_conv(): - """Conv followed by ReduceMean over the spatial dimensions.""" - rng = np.random.default_rng(62) - X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 1, 1]) - W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W") - B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B") - conv = helper.make_node("Conv", ["X", "W", "B"], ["C"], - kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) - reduce = helper.make_node("ReduceMean", ["C"], ["Y"], axes=[2, 3], keepdims=1) - graph = helper.make_graph([conv, reduce], "reducemean_after_conv", [X], [Y], initializer=[W, B]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "reduce_mean/after_conv", "reduce_mean_after_conv.onnx") - - # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- @@ -699,6 +784,24 @@ if __name__ == "__main__": sigmoid_4d() sigmoid_after_gemm() + print("\nGenerating Split tests:") + split_basic() + split_equal_three_way() + + print("\nGenerating Softmax tests:") + softmax_basic() + softmax_3d_last_axis() + softmax_channel_axis() + + print("\nGenerating Resize tests:") + resize_nearest_2x() + resize_nearest_non_uniform() + resize_with_sizes() + + print("\nGenerating Gather tests:") + gather_axis1() + gather_axis0_matrix_indices() + print("\nGenerating Add tests:") add_basic() add_broadcast_row() diff --git a/validation/operations/resize/nearest_2x/resize_nearest_2x.onnx b/validation/operations/resize/nearest_2x/resize_nearest_2x.onnx new file mode 100644 index 0000000..b4b8f5f Binary files /dev/null and b/validation/operations/resize/nearest_2x/resize_nearest_2x.onnx differ diff --git a/validation/operations/resize/non_uniform/resize_non_uniform.onnx b/validation/operations/resize/non_uniform/resize_non_uniform.onnx new file mode 100644 index 0000000..9003cee Binary files /dev/null and b/validation/operations/resize/non_uniform/resize_non_uniform.onnx differ diff --git a/validation/operations/resize/with_sizes/resize_with_sizes.onnx b/validation/operations/resize/with_sizes/resize_with_sizes.onnx new file mode 100644 index 0000000..227b586 Binary files /dev/null and b/validation/operations/resize/with_sizes/resize_with_sizes.onnx differ diff --git a/validation/operations/softmax/3d_last_axis/softmax_3d_last_axis.onnx b/validation/operations/softmax/3d_last_axis/softmax_3d_last_axis.onnx new file mode 100644 index 0000000..08ffdb8 Binary files /dev/null and b/validation/operations/softmax/3d_last_axis/softmax_3d_last_axis.onnx differ diff --git a/validation/operations/softmax/basic/softmax_basic.onnx b/validation/operations/softmax/basic/softmax_basic.onnx new file mode 100644 index 0000000..689313b Binary files /dev/null and b/validation/operations/softmax/basic/softmax_basic.onnx differ diff --git a/validation/operations/softmax/channel_axis/softmax_channel_axis.onnx b/validation/operations/softmax/channel_axis/softmax_channel_axis.onnx new file mode 100644 index 0000000..7bb2f90 Binary files /dev/null and b/validation/operations/softmax/channel_axis/softmax_channel_axis.onnx differ diff --git a/validation/operations/split/basic/split_basic.onnx b/validation/operations/split/basic/split_basic.onnx new file mode 100644 index 0000000..22ddd19 Binary files /dev/null and b/validation/operations/split/basic/split_basic.onnx differ diff --git a/validation/operations/split/equal_three_way/split_equal_three_way.onnx b/validation/operations/split/equal_three_way/split_equal_three_way.onnx new file mode 100644 index 0000000..03d4216 Binary files /dev/null and b/validation/operations/split/equal_three_way/split_equal_three_way.onnx differ