diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index e443e17..ad99bf5 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -43,7 +43,7 @@ add_onnx_mlir_library(OMPIMAccel PimOps OMONNXToSpatial OMSpatialToGraphviz - OMSpatialToPIM - OMPIMCommon + OMSpatialToPim + OMPimCommon MLIRTensorInferTypeOpInterfaceImpl ) diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 5ca75f5..57c79f2 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -1,5 +1,5 @@ -add_onnx_mlir_library(OMPIMCommon - PIMCommon.cpp +add_onnx_mlir_library(OMPimCommon + PimCommon.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Common/PIMCommon.cpp b/src/PIM/Common/PimCommon.cpp similarity index 94% rename from src/PIM/Common/PIMCommon.cpp rename to src/PIM/Common/PimCommon.cpp index f648c57..f215a26 100644 --- a/src/PIM/Common/PIMCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -3,10 +3,10 @@ #include #include -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Compiler/CompilerOptions.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; @@ -37,8 +37,7 @@ FailureOr getPimEntryFunc(ModuleOp moduleOp) { SmallVector entryPoints(moduleOp.getOps()); if (entryPoints.size() > 1) { - moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") - << entryPoints.size(); + moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size(); return failure(); } if (!entryPoints.empty()) { @@ -61,10 +60,9 @@ FailureOr getPimEntryFunc(ModuleOp moduleOp) { return mainGraphFunc; SmallVector nonExternalFuncs; - for (auto funcOp : moduleOp.getOps()) { + for (auto funcOp : moduleOp.getOps()) if (!funcOp.isExternal()) nonExternalFuncs.push_back(funcOp); - } if (nonExternalFuncs.size() == 1) return nonExternalFuncs.front(); @@ -72,11 +70,11 @@ FailureOr getPimEntryFunc(ModuleOp moduleOp) { return failure(); } -bool hasWeightAlways(Operation* op) { return op && op->getAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME) != nullptr; } +bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; } void markWeightAlways(Operation* op) { assert(op && "expected valid op"); - op->setAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME, UnitAttr::get(op->getContext())); + op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext())); } memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { diff --git a/src/PIM/Common/PIMCommon.hpp b/src/PIM/Common/PimCommon.hpp similarity index 83% rename from src/PIM/Common/PIMCommon.hpp rename to src/PIM/Common/PimCommon.hpp index 153d89a..e8411a7 100644 --- a/src/PIM/Common/PIMCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -1,8 +1,8 @@ #pragma once #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Operation.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" @@ -10,8 +10,8 @@ #include "src/Compiler/CompilerOptions.hpp" -const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate"; -inline constexpr llvm::StringRef PIM_WEIGHT_ALWAYS_ATTR_NAME = "weightAlways"; +const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate"; +inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index a89eae1..4569da6 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -13,9 +13,9 @@ #include #include -#include "Common/PIMCommon.hpp" +#include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" +#include "Conversion/SpatialToPim/SpatialToPimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index 69199a2..443da4c 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -34,7 +34,7 @@ void addPassesPim(OwningOpRef& module, } if (pimEmissionTarget >= EmitPim) { - pm.addPass(createSpatialToPIMPass()); + pm.addPass(createSpatialToPimPass()); // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Spatial lowered to Pim")); } diff --git a/src/PIM/Conversion/CMakeLists.txt b/src/PIM/Conversion/CMakeLists.txt index 27c58b8..dcbc3d9 100644 --- a/src/PIM/Conversion/CMakeLists.txt +++ b/src/PIM/Conversion/CMakeLists.txt @@ -1,3 +1,3 @@ add_subdirectory(ONNXToSpatial) add_subdirectory(SpatialToGraphviz) -add_subdirectory(SpatialToPIM) \ No newline at end of file +add_subdirectory(SpatialToPim) \ No newline at end of file diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index e105edd..85beb04 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -23,7 +23,7 @@ add_onnx_mlir_library(OMONNXToSpatial OMPimCompilerOptions OMONNXOps SpatialOps - OMPIMCommon + OMPimCommon ACCEL_INCLUDE_DIRS PRIVATE ${PIM_INCLUDE_PATH} diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp index f4acf18..b3c44c9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -10,7 +10,7 @@ #include -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -91,7 +91,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); - auto aSlice = rewriter.create(loc, aSliceType, a, offsets, sizes, strides).getResult(); + auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult(); Value cSlice = c; if (hasC) { @@ -100,26 +100,27 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); - cSlice = rewriter.create(loc, cSliceType, c, offsets, sizes, strides).getResult(); + cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult(); } else assert("C should be a vector" && isVectorShape(getTensorShape(c))); } - auto gemvOp = rewriter.create(loc, - outRowType, - aSlice, - b, - cSlice, - gemmOp.getAlphaAttr(), - gemmOp.getBetaAttr(), - gemmOp.getTransAAttr(), - gemmOp.getTransBAttr()); + auto gemvOp = ONNXGemmOp::create(rewriter, + loc, + outRowType, + aSlice, + b, + cSlice, + gemmOp.getAlphaAttr(), + gemmOp.getBetaAttr(), + gemmOp.getTransAAttr(), + gemmOp.getTransBAttr()); gemvOps.push_back(gemvOp.getY()); } auto concatComputeOp = - rewriter.create(loc, gemmOp.getType(), SmallVector(), gemvOps); + spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector(), gemvOps); auto* concatBlock = new Block(); for (auto gemvOp : gemvOps) @@ -128,8 +129,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, rewriter.setInsertionPointToStart(concatBlock); auto blockArgs = concatBlock->getArguments(); - auto concatOp = rewriter.create(loc, /*axis=*/0, blockArgs); - rewriter.create(loc, concatOp.getResult()); + auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs); + spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); rewriter.replaceOp(gemmOp, concatComputeOp); return success(); @@ -170,25 +171,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, if (transA) { auto aShape = aType.getShape(); auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); - a = rewriter.create(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); + a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); } if (transB) { auto bShape = bType.getShape(); auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); - b = rewriter.create(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); } if (alpha != 1.0f) { auto alphaTensorType = RankedTensorType::get({1, 1}, cast(a.getType()).getElementType()); auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); - auto alphaTensor = rewriter.create(gemmLoc, alphaTensorType, alphaTensorValue); - a = rewriter.create(gemmLoc, a.getType(), a, alphaTensor); + auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue); + a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor); } if (hasC && beta != 1.0f) { auto betaTensorType = RankedTensorType::get({1, 1}, cast(c.getType()).getElementType()); auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); - auto betaTensor = rewriter.create(gemmLoc, betaTensorType, betaTensorValue); - c = rewriter.create(gemmLoc, c.getType(), c, betaTensor); + auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue); + c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor); } auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); @@ -235,7 +236,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, weights.push_back(bTiles[outSliceId][coreId][aSliceId]); auto computeOp = - rewriter.create(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); + spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); auto* computeBlock = new Block(); for (auto aHSlice : aHSlices[coreId]) @@ -248,11 +249,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, vmmOutputs.reserve(computeArgs.size()); for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) vmmOutputs.push_back( - rewriter.create(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); + spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); Value partialVmmSum = sumTensors(vmmOutputs, rewriter); - rewriter.create(gemmLoc, partialVmmSum); + spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); rewriter.setInsertionPointAfter(computeOp); partialResults.push_back(computeOp.getResult(0)); @@ -264,7 +265,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, } auto reduceComputeOp = - rewriter.create(gemmLoc, currOutHSliceType, SmallVector(), partialResults); + spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector(), partialResults); auto* reduceBlock = new Block(); for (auto partialResult : partialResults) @@ -274,14 +275,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, auto blockArgs = reduceBlock->getArguments(); Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); - rewriter.create(gemmLoc, outHSlice); + spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice); rewriter.setInsertionPointAfter(reduceComputeOp); outHSlices.push_back(reduceComputeOp.getResult(0)); } auto concatComputeOp = - rewriter.create(gemmLoc, gemmOp.getType(), SmallVector(), outHSlices); + spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector(), outHSlices); auto* concatBlock = new Block(); for (auto outHSlice : outHSlices) @@ -290,8 +291,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, rewriter.setInsertionPointToStart(concatBlock); auto blockArgs = concatBlock->getArguments(); - auto concatOp = rewriter.create(gemmLoc, /*axis=*/1, blockArgs); - rewriter.create(gemmLoc, concatOp.getResult()); + auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs); + spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult()); rewriter.replaceOp(gemmOp, concatComputeOp); return success(); @@ -335,9 +336,9 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector(loc, scalarTensorType, a, b); }, + [&](Value a, Value b) { return spatial::SpatVAddOp::create(rewriter, loc, scalarTensorType, a, b); }, /* preprocess = */ - [&](Value a) { return rewriter.create(loc, scalarTensorType, a); }, + [&](Value a) { return spatial::SpatSumOp::create(rewriter, loc, scalarTensorType, a); }, [&](Value softmaxDivisor) { // Signal that this is the compute with the softmax divisor auto computeOp = cast(softmaxDivisor.getDefiningOp()->getParentOp()); @@ -345,7 +346,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector(loc, softmaxChannel, softmaxDivisor); + spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, softmaxChannel, softmaxDivisor); /* * softmaxDividend = onnx.exp (...) @@ -395,7 +396,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector(loc, scalarTensorType, softmaxChannel); + divisor = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, loc, scalarTensorType, softmaxChannel); } // Walk the chain of operations until we find the ONNXExpOp: this is @@ -405,7 +406,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVectorgetOperand(computeToDivideOpAndResNum.second)); rewriter.setInsertionPoint(yieldOp); - Value newOutputTile = rewriter.create(loc, oldOutputTile.getType(), oldOutputTile, divisor); + Value newOutputTile = spatial::SpatVSDivOp::create(rewriter, loc, oldOutputTile.getType(), oldOutputTile, divisor); auto yieldOperandNum = yieldOp->getNumOperands(); yieldOp->insertOperands(yieldOperandNum, newOutputTile); diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp index c440d80..cc0b3c5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp @@ -15,7 +15,7 @@ #include #include -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" @@ -111,15 +111,15 @@ Value applyReducePatternNew(SmallVector& valuesToReduce, // 1. Add a channel before the first computeOp rewriter.setInsertionPoint(firstCompute); - auto channel = rewriter.create(loc, channelType); + auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType); // 2. Add a sendOp after the first value rewriter.setInsertionPointAfterValue(firstValue); - rewriter.create(loc, channel, firstValue); + spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue); // 3. Add a receiveOp after the second value rewriter.setInsertionPointAfterValue(secondValue); - auto receivedValue = rewriter.create(loc, secondValue.getType(), channel); + auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel); // 4. Apply reduction between second value and received value rewriter.setInsertionPointAfterValue(receivedValue); @@ -188,13 +188,14 @@ Value postProcessPoolingWindow(ConversionPatternRewriter& rew // directly under func.func (i.e. alongside ComputeOps) auto computeOp = cast(valueToDivide.getDefiningOp()->getParentOp()); rewriter.setInsertionPoint(computeOp); - auto divisorValue = rewriter.create(loc, - scalarTensor, - rewriter.getI64IntegerAttr(divisorNumber), - /* should_allocate = */ rewriter.getBoolAttr(true)); + auto divisorValue = spatial::SpatConstantOp::create(rewriter, + loc, + scalarTensor, + rewriter.getI64IntegerAttr(divisorNumber), + /* should_allocate = */ rewriter.getBoolAttr(true)); rewriter.setInsertionPointAfterValue(valueToDivide); - return rewriter.create(loc, valueToDivide.getType(), valueToDivide, divisorValue); + return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue); } template @@ -257,17 +258,18 @@ struct PoolingBaseConverter : public OpConversionPattern { if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp()) { Location tileLoc = extractSliceOp.getLoc(); - auto tempComputeOp = rewriter.create(tileLoc, - extractSliceOp.getResultType(), - /* xbarWeights =*/ValueRange(), - extractSliceOp.getResult()); + auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter, + tileLoc, + extractSliceOp.getResultType(), + /* xbarWeights =*/ValueRange(), + extractSliceOp.getResult()); Block* tempComputeOpBlock = new Block(); tempComputeOp.getBody().push_back(tempComputeOpBlock); auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc); rewriter.setInsertionPointToStart(tempComputeOpBlock); - rewriter.create(tileLoc, tempComputeOpBlockArg); + spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg); rewriter.setInsertionPointAfter(tempComputeOp); inputTiles[t][x][y] = tempComputeOp.getResult(0); } @@ -356,7 +358,7 @@ struct PoolingBaseConverter : public OpConversionPattern { Value reducedWithinCompute = applyReducePatternNew( valuesToPool, rewriter, - [&](const Value lhs, const Value rhs) { return rewriter.create(loc, lhs.getType(), lhs, rhs); }, + [&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); }, nullptr, postProcessFn); @@ -369,16 +371,16 @@ struct PoolingBaseConverter : public OpConversionPattern { // Create a new channel before the computeOp rewriter.setInsertionPoint(computeOpOfReduced); auto reduceChannel = - rewriter.create(loc, spatial::SpatChannelType::get(rewriter.getContext())); + spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext())); // Send value through the channel rewriter.setInsertionPointAfterValue(reducedWithinCompute); - rewriter.create(loc, reduceChannel, reducedWithinCompute); + spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute); // Receive after the computeOp rewriter.setInsertionPointAfter(computeOpOfReduced); auto receivedValue = - rewriter.create(loc, reducedWithinCompute.getType(), reduceChannel); + spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel); outputTiles[outTile][outX][outY] = receivedValue; } diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp index 5906242..e55693d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp @@ -63,16 +63,17 @@ struct ReduceMeanConversionPattern : public OpConversionPattern(reduceMean.getLoc(), - resultType, - inputTensor, - /*auto_pad=*/"NOTSET", - /*ceil_mode=*/0, - /*count_include_pad=*/1, - dilations, - /*kernel_shape=*/kernelShape, - /*pads=*/pads, - /*strides=*/strides); + auto averagePool = ONNXAveragePoolOp::create(rewriter, + reduceMean.getLoc(), + resultType, + inputTensor, + /*auto_pad=*/"NOTSET", + /*ceil_mode=*/0, + /*count_include_pad=*/1, + dilations, + /*kernel_shape=*/kernelShape, + /*pads=*/pads, + /*strides=*/strides); // Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp. rewriter.replaceOp(reduceMean, averagePool.getResult()); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td index a4e0bc9..f83d758 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td @@ -29,7 +29,7 @@ def matMulToGemmPattern : Pat< (ONNXMatMulOp:$matmulres $A, $B), ( ONNXGemmOp $A, $B, - /* C = */ (NativeCodeCall<"$_builder.create($_loc, cast(matmulres.getY().getType()).getShape(), cast(matmulres.getY().getType()).getElementType());">), + /* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast(matmulres.getY().getType()).getShape(), cast(matmulres.getY().getType()).getElementType());">), /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), @@ -70,4 +70,4 @@ def removeFlattenSameShapePattern : Pat< [(HaveSameStaticShape $flattenOp, $A)] >; // Add closing parenthesis here -#endif // ONNX_TO_SPATIAL \ No newline at end of file +#endif // ONNX_TO_SPATIAL diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp index 6d8a849..e633c34 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp @@ -47,7 +47,7 @@ SmallVector sliceTensor( if (i == numSlices - 1 && lastSliceSize != 0) sizes[axis] = rewriter.getIndexAttr(lastSliceSize); - Value slice = rewriter.create(loc, tensorToSlice, offsets, sizes, strides); + Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides); slices.push_back(slice); } @@ -100,11 +100,11 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr int64_t shape[2] = {1, length}; Type type = oldType.cloneWith(ArrayRef(shape), elementType); - auto zero = rewriter.create(loc, 0).getResult(); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); SmallVector index(oldType.getRank(), zero); - auto elementValue = rewriter.create(loc, scalarToBroadcast, index).getResult(); + auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult(); - return rewriter.create(loc, type, elementValue); + return tensor::SplatOp::create(rewriter, loc, type, elementValue); } Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter) { @@ -122,7 +122,7 @@ Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter) { Value a = (*currTensors)[i]; Value b = (*currTensors)[i + 1]; rewriter.setInsertionPointAfterValue(b); - auto addedValue = rewriter.create(a.getLoc(), a.getType(), a, b); + auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b); nextTensors->push_back(addedValue); } if (currTensors->size() % 2 == 1) @@ -137,10 +137,10 @@ Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter) { Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) { switch (mapOp) { case MapOperations::None: assert(false && "Invalid map operation during map operation creation."); - case MapOperations::ONNXSoftmaxOp: return rewriter.create(input.getLoc(), input.getType(), input); - case MapOperations::ONNXReluOp: return rewriter.create(input.getLoc(), input.getType(), input); - case MapOperations::ONNXLeakyReluOp: return rewriter.create(input.getLoc(), input.getType(), input); - case MapOperations::ONNXExpOp: return rewriter.create(input.getLoc(), input.getType(), input); + case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input); + case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input); + case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input); + case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input); } } @@ -201,7 +201,7 @@ void tileImageTensorByChannel(Value imageTensor, offsets[2] = rewriter.getIndexAttr(x); offsets[3] = rewriter.getIndexAttr(y); - tiles[i][x][y] = rewriter.create(loc, imageTensor, offsets, sizes, strides); + tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides); } } } @@ -225,7 +225,7 @@ Value createImgConcatOp(SmallVector>>& outputTile for (size_t outTile = 0; outTile < outputTileCount; outTile++) tilesToConcat.push_back(outputTiles[outTile][outX][outY]); - return rewriter.create(loc, outputType, tilesToConcat); + return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat); } LogicalResult @@ -271,7 +271,7 @@ Value createExtractSliceImg(Value valToSlice, offsets[2] = rewriter.getIndexAttr(x); offsets[3] = rewriter.getIndexAttr(y); - return rewriter.create(valToSlice.getLoc(), valToSlice, offsets, sizes, strides); + return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides); } Value indexImgValue(Value v, @@ -384,7 +384,7 @@ void resolveInputTensorTilesBlockArg(Value wholeInputTensor, offsets[2] = rewriter.getIndexAttr(x); offsets[3] = rewriter.getIndexAttr(y); - inputTiles[t][x][y] = rewriter.create(loc, wholeInputTensor, offsets, sizes, strides); + inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides); } } } @@ -452,7 +452,7 @@ LogicalResult handleFlattenLikeOp(SmallVector>& inputTiles, SmallVector newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)}; auto shapeType = RankedTensorType::get({static_cast(newShapeVals.size())}, rewriter.getI64Type()); Value shapeTensor = - rewriter.create(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals)); + arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals)); auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType()); auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 17ee2f1..0bd82d0 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -9,7 +9,7 @@ #include -#include "Common/PIMCommon.hpp" +#include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp index 3187300..36d0573 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp @@ -272,8 +272,8 @@ void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) { // Create a new ComputeOp with the new result type, but same operands rewriter.setInsertionPoint(oldComputeOp); - auto newComputeOp = rewriter.create( - oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); + auto newComputeOp = spatial::SpatWeightedCompute::create( + rewriter, oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); newComputeOp.getBody().takeBody(oldComputeOp.getBody()); @@ -333,14 +333,14 @@ OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector(a.getLoc(), a.getType(), a, biasTile); + mapOperand = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, biasTile); return createMapOperation(rewriter, mapOp, mapOperand); }; } return this->applyReducePattern( computeOps, - [&](mlir::Value a, mlir::Value b) { return rewriter.create(a.getLoc(), a.getType(), a, b); }, + [&](mlir::Value a, mlir::Value b) { return spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b); }, /* preprocess = */ nullptr, postprocessing); } diff --git a/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt b/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt index 1e849a4..cf79a3c 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToGraphviz/CMakeLists.txt @@ -5,7 +5,7 @@ add_onnx_mlir_library(OMSpatialToGraphviz LINK_LIBS PUBLIC OMCompilerOptions - OMPIMCommon + OMPimCommon OMONNXOps SpatialOps diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp index d331444..d510e01 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -10,7 +10,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Format.h" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt b/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt deleted file mode 100644 index 0fcd649..0000000 --- a/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td) -mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") -add_public_tablegen_target(SpatialToPIMIncGen) - -add_onnx_mlir_library(OMSpatialToPIM - SpatialToPIMPass.cpp - SpatialToPIMCommon.cpp - - DEPENDS - SpatialToPIMIncGen - - LINK_LIBS PUBLIC - OMCompilerOptions - OMPIMCommon - SpatialOps - PimOps - - ACCEL_INCLUDE_DIRS PRIVATE - ${PIM_INCLUDE_PATH} -) diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt new file mode 100644 index 0000000..33b42d8 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -0,0 +1,20 @@ +set(LLVM_TARGET_DEFINITIONS SpatialToPim.td) +mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") +add_public_tablegen_target(SpatialToPimIncGen) + +add_onnx_mlir_library(OMSpatialToPim + SpatialToPimPass.cpp + SpatialToPimCommon.cpp + + DEPENDS + SpatialToPimIncGen + + LINK_LIBS PUBLIC + OMCompilerOptions + OMPimCommon + SpatialOps + PimOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${PIM_INCLUDE_PATH} +) diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td similarity index 100% rename from src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td rename to src/PIM/Conversion/SpatialToPim/SpatialToPim.td diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.cpp similarity index 98% rename from src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp rename to src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.cpp index 958e6c8..5c17eaf 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.cpp @@ -5,7 +5,7 @@ #include #include -#include "SpatialToPIMCommon.hpp" +#include "SpatialToPimCommon.hpp" using namespace llvm; using namespace mlir; diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp similarity index 100% rename from src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp rename to src/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp similarity index 86% rename from src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp rename to src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 13fd2f7..2bf3a49 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -18,8 +18,8 @@ #include #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" @@ -33,15 +33,15 @@ namespace onnx_mlir { namespace { -#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc" -struct SpatialToPIMPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass) +struct SpatialToPimPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass) StringRef getArgument() const override { return "convert-spatial-to-pim"; } StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } - SpatialToPIMPass() = default; - SpatialToPIMPass(const SpatialToPIMPass& pass) {} + SpatialToPimPass() = default; + SpatialToPimPass(const SpatialToPimPass& pass) {} void runOnOperation() final; @@ -76,7 +76,7 @@ private: } // namespace -void SpatialToPIMPass::runOnOperation() { +void SpatialToPimPass::runOnOperation() { coreId = 1; ModuleOp moduleOp = getOperation(); MLIRContext* ctx = moduleOp.getContext(); @@ -132,7 +132,7 @@ void SpatialToPIMPass::runOnOperation() { dumpModule(moduleOp, "pim"); } -void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { +void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { Location loc = computeOp->getLoc(); auto& block = computeOp.getRegion().front(); @@ -180,13 +180,14 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR // Store to global memory Value outputTensor = outputTensors[resultIndexInReturn]; rewriter.setInsertionPointAfterValue(yieldValue); - rewriter.create(loc, - outputTensor.getType(), - outputTensor, - yieldValue, - rewriter.getI32IntegerAttr(offset), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(numElements * elementSize)); + PimMemCopyDevToHostOp::create(rewriter, + loc, + outputTensor.getType(), + outputTensor, + yieldValue, + rewriter.getI32IntegerAttr(offset), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(numElements * elementSize)); continue; } @@ -211,14 +212,14 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR // Store to global memory Value outputTensor = outputTensors[concatIndexInReturn]; rewriter.setInsertionPointAfterValue(yieldValue); - rewriter.create( - loc, - outputTensor.getType(), - outputTensor, - yieldValue, - rewriter.getI32IntegerAttr(offset), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize)); + PimMemCopyDevToHostOp::create(rewriter, + loc, + outputTensor.getType(), + outputTensor, + yieldValue, + rewriter.getI32IntegerAttr(offset), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize)); continue; } } @@ -230,7 +231,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR // 1. Create a new ChannelOp rewriter.setInsertionPoint(computeOp); auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); - auto channelOp = rewriter.create(loc, channelType); + auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType); // 2. Receive value through the channel // If this result is used by more than one user, then use a "Broadcast" @@ -244,9 +245,9 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR // 3. Send the value through the channel rewriter.setInsertionPointAfterValue(yieldValue); if (useBroadcastOp) - rewriter.create(loc, channelOp, yieldValue); + spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue); else - rewriter.create(loc, channelOp, yieldValue); + spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue); } // Use `HaltOp` instead of `YieldOp` @@ -255,17 +256,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR // Replace `spat.compute` with `pim.core` rewriter.setInsertionPointAfter(computeOp); - auto coreOp = rewriter.create(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++)); + auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++)); auto& coreOpBlocks = coreOp.getBody().getBlocks(); block.eraseArguments(0, block.getNumArguments()); coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); Block* tempComputeBlock = new Block(); computeOp.getBody().push_back(tempComputeBlock); rewriter.setInsertionPointToEnd(tempComputeBlock); - rewriter.create(computeOp.getLoc()); + PimHaltOp::create(rewriter, computeOp.getLoc()); } -void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { +void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { auto* definingOp = value.getDefiningOp(); if (!definingOp) @@ -302,14 +303,14 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I SmallVector sizes = {oldShapeZeroAttr, oldShapeOneAttr}; SmallVector strides = {oneAttr, oneAttr}; rewriter.setInsertionPointAfter(vmmOp); - auto sliceOp = rewriter.create(vmmOp.getLoc(), resultTensor, offsets, sizes, strides); + auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides); SmallPtrSet exceptions = {vmmOp, sliceOp}; resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions); } }); } -void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { +void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { outputTensors.reserve(returnOp->getNumOperands()); rewriter.setInsertionPointToStart(returnOp->getBlock()); for (auto returnValue : returnOp->getOperands()) { @@ -326,7 +327,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew } } -void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { +void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { @@ -335,9 +336,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); - auto deviceTensor = rewriter.create(loc, tensorType.getShape(), elementType); + auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); - auto memCopyHostToDevOp = rewriter.create( + auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create( + rewriter, loc, tensorType, deviceTensor, @@ -362,7 +364,8 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func Block& block = funcOp.getBody().front(); rewriter.setInsertionPoint(&block.front()); - auto toTensorOp = rewriter.create(loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); + auto toTensorOp = + bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); inputTensors.push_back(toTensorOp); tensorArg.replaceAllUsesWith(toTensorOp); @@ -415,7 +418,7 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func rewriter.eraseOp(sliceOp); } -void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, +void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, unsigned int argIndex, spatial::SpatChannelNewOp& channel, Type& tensorType, @@ -431,14 +434,14 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); Value receivedValue; if (useBroadcastOp) - receivedValue = rewriter.create(computeOp.getLoc(), tensorType, channel); + receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel); else - receivedValue = rewriter.create(computeOp.getLoc(), tensorType, channel); + receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel); blockArg.replaceAllUsesWith(receivedValue); } -void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp, +void SpatialToPimPass::addReceiveOps(Value& channelSourceOp, spatial::SpatChannelNewOp& channel, Type& channelTensorType, bool& useBroadcastOp, @@ -495,7 +498,7 @@ void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp, } } -void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { +void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { for (auto it : llvm::enumerate(returnOp.getOperands())) { Operation* returnOperand = it.value().getDefiningOp(); @@ -514,7 +517,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri } } -void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { +void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { auto channel = cast(receiveOp.getChannel().getDefiningOp()); @@ -543,6 +546,6 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I } } -std::unique_ptr createSpatialToPIMPass() { return std::make_unique(); } +std::unique_ptr createSpatialToPimPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPatterns.hpp similarity index 100% rename from src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp rename to src/PIM/Conversion/SpatialToPim/SpatialToPimPatterns.hpp diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt index 1f17eec..277306f 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt @@ -13,7 +13,7 @@ add_onnx_mlir_library(OMPimBufferization PimBufferizationIncGen LINK_LIBS PUBLIC - OMPIMCommon + OMPimCommon PimOps ACCEL_INCLUDE_DIRS PRIVATE diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index e5badef..996eed4 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -5,7 +5,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" -#include "Common/PIMCommon.hpp" +#include "Common/PimCommon.hpp" #include "Compiler/PimCodeGen.hpp" #include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp" diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index e888e77..cc56c90 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -17,7 +17,7 @@ #include -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" @@ -34,7 +34,7 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType()); // Alloc an output memref - return rewriter.create(loc, memrefResultType); + return memref::AllocOp::create(rewriter, loc, memrefResultType); } const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); @@ -134,7 +134,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa memrefOperands.push_back(outputTensor); - Value newValue = rewriter.create(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes(); + Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes(); replaceOpWithBufferizedValues(rewriter, op, newValue); @@ -169,11 +169,13 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod // Alloc an output memref Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - Value newValue = - rewriter - .create( - op->getLoc(), outputTensor.getType(), cast(op).getWeightIndexAttr(), memrefOperand, outputTensor) - .getOutRes(); + Value newValue = ToTy::create(rewriter, + op->getLoc(), + outputTensor.getType(), + cast(op).getWeightIndexAttr(), + memrefOperand, + outputTensor) + .getOutRes(); replaceOpWithBufferizedValues(rewriter, op, newValue); @@ -213,12 +215,12 @@ struct ChannelReceiveOpInterface if (failed(srcCoreId)) return failure(); - Value newValue = rewriter - .create(op->getLoc(), - outputTensor.getType(), - outputTensor, - rewriter.getI32IntegerAttr(numElements * elementSize), - rewriter.getI32IntegerAttr(srcCoreId.value())) + Value newValue = pim::PimReceiveOp::create(rewriter, + op->getLoc(), + outputTensor.getType(), + outputTensor, + rewriter.getI32IntegerAttr(numElements * elementSize), + rewriter.getI32IntegerAttr(srcCoreId.value())) .getOut(); replaceOpWithBufferizedValues(rewriter, op, newValue); @@ -324,13 +326,14 @@ struct ChannelBroadcastReceiveOpInterface } rewriter.setInsertionPoint(op); - auto memCopyHostToDevOp = rewriter.create(op->getLoc(), - outputTensor.getType(), - outputTensor, - bufferAllocation, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(outputSize)); + auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter, + op->getLoc(), + outputTensor.getType(), + outputTensor, + bufferAllocation, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(outputSize)); replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst()); @@ -395,13 +398,14 @@ struct ChannelBroadcastSendOpInterface auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8; rewriter.setInsertionPoint(op); - rewriter.create(op->getLoc(), - bufferAllocation.getType(), - bufferAllocation, - srcMemRef, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(sizeInBytes)); + pim::PimMemCopyDevToHostOp::create(rewriter, + op->getLoc(), + bufferAllocation.getType(), + bufferAllocation, + srcMemRef, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(sizeInBytes)); rewriter.eraseOp(op); return success(); } @@ -481,14 +485,15 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel(op).getXKernelPositionsAttr(); auto yKernelPositions = cast(op).getYKernelPositionsAttr(); - Value bufferized = rewriter.create(op->getLoc(), - outputTensor.getType(), - weightIndices, - xKernelPositions, - yKernelPositions, - *inputBuffer, - outputTensor, - accumBuffer); + Value bufferized = pim::PimApplyFiltersOp::create(rewriter, + op->getLoc(), + outputTensor.getType(), + weightIndices, + xKernelPositions, + yKernelPositions, + *inputBuffer, + outputTensor, + accumBuffer); // Replace the operation with the bufferized value. replaceOpWithBufferizedValues(rewriter, op, bufferized); diff --git a/src/PIM/Pass/EmitPimJsonPass.cpp b/src/PIM/Pass/EmitPimJsonPass.cpp index fc20365..baa72d2 100644 --- a/src/PIM/Pass/EmitPimJsonPass.cpp +++ b/src/PIM/Pass/EmitPimJsonPass.cpp @@ -1,6 +1,6 @@ #include "mlir/Pass/Pass.h" -#include "Common/PIMCommon.hpp" +#include "Common/PimCommon.hpp" #include "Compiler/PimCodeGen.hpp" using namespace mlir; diff --git a/src/PIM/Pass/PimFoldHostConstantsPass.cpp b/src/PIM/Pass/PimFoldHostConstantsPass.cpp index f873c09..18a22f0 100644 --- a/src/PIM/Pass/PimFoldHostConstantsPass.cpp +++ b/src/PIM/Pass/PimFoldHostConstantsPass.cpp @@ -12,7 +12,7 @@ #include -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; diff --git a/src/PIM/Pass/PimHostVerificationPass.cpp b/src/PIM/Pass/PimHostVerificationPass.cpp index b2bedd5..c82f9ef 100644 --- a/src/PIM/Pass/PimHostVerificationPass.cpp +++ b/src/PIM/Pass/PimHostVerificationPass.cpp @@ -5,7 +5,7 @@ #include "llvm/ADT/STLExtras.h" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" diff --git a/src/PIM/Pass/PimPasses.hpp b/src/PIM/Pass/PimPasses.hpp index 781c073..43fa213 100644 --- a/src/PIM/Pass/PimPasses.hpp +++ b/src/PIM/Pass/PimPasses.hpp @@ -11,7 +11,7 @@ std::unique_ptr createONNXToSpatialPass(); std::unique_ptr createSpatialToGraphvizPass(); -std::unique_ptr createSpatialToPIMPass(); +std::unique_ptr createSpatialToPimPass(); std::unique_ptr createBufferizePimPass(); diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 997370f..6188c3a 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -71,7 +71,7 @@ void PimAccelerator::registerPasses(int optLevel) const { LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); registerPass(createONNXToSpatialPass); registerPass(createSpatialToGraphvizPass); - registerPass(createSpatialToPIMPass); + registerPass(createSpatialToPimPass); registerPass(createBufferizePimPass); registerPass(createPimFoldHostConstantsPass); registerPass(createPimHostVerificationPass);