diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 0b9fadd..f31492e 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -6,13 +6,11 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "src/Compiler/CompilerOptions.hpp" -const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate"; inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 20b71c3..192d15a 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -194,45 +194,45 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const { emitMemCopyOp("ld", - memory.getValueAddress(loadOp.getDeviceDst()), - loadOp.getDeviceDstOffset(), - memory.getValueAddress(loadOp.getHostSrc()), - loadOp.getHostSrcOffset(), + memory.getValueAddress(loadOp.getDeviceTarget()), + loadOp.getDeviceTargetOffset(), + memory.getValueAddress(loadOp.getHostSource()), + loadOp.getHostSourceOffset(), loadOp.getSize()); } void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const { emitMemCopyOp("st", - memory.getValueAddress(storeOp.getHostDst()), - storeOp.getHostDstOffset(), - memory.getValueAddress(storeOp.getDeviceSrc()), - storeOp.getDeviceSrcOffset(), + memory.getValueAddress(storeOp.getHostTarget()), + storeOp.getHostTargetOffset(), + memory.getValueAddress(storeOp.getDeviceSource()), + storeOp.getDeviceSourceOffset(), storeOp.getSize()); } void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const { emitMemCopyOp("lmv", - memory.getValueAddress(lmvOp.getDst()), - lmvOp.getDstOffset(), - memory.getValueAddress(lmvOp.getSrc()), - lmvOp.getSrcOffset(), + memory.getValueAddress(lmvOp.getTarget()), + lmvOp.getTargetOffset(), + memory.getValueAddress(lmvOp.getSource()), + lmvOp.getSourceOffset(), lmvOp.getSize(), "len"); } void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const { emitCommunicationOp( - "recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize()); + "recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize()); } void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const { - emitCommunicationOp("send", memory.getValueAddress(sendOp.getSrc()), sendOp.getTargetCoreId(), sendOp.getSize()); + emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize()); } template void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) { emitMvmOp( - mvmId, memory.getValueAddress(mvmLikeOp.getOutBuf()), 0, memory.getValueAddress(mvmLikeOp.getVectorInput()), 0); + mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0); // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) } @@ -243,10 +243,10 @@ static size_t getValueSizeInBytes(mlir::Value value) { } void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const { - auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vvaddOp.getA()); - auto bAddr = memory.getValueAddress(vvaddOp.getB()); - setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer()); + auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs()); + auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs()); + setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; json["op"] = "vvadd"; @@ -254,15 +254,15 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vvaddOp.getA()); + json["len"] = getValueSizeInBytes(vvaddOp.getLhs()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const { - auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vvsubOp.getA()); - auto bAddr = memory.getValueAddress(vvsubOp.getB()); - setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer()); + auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs()); + auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs()); + setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; json["op"] = "vvsub"; @@ -270,15 +270,15 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vvsubOp.getA()); + json["len"] = getValueSizeInBytes(vvsubOp.getLhs()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const { - auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vvmulOp.getA()); - auto bAddr = memory.getValueAddress(vvmulOp.getB()); - setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer()); + auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs()); + auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs()); + setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; json["op"] = "vvmul"; @@ -286,15 +286,15 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vvmulOp.getA()); + json["len"] = getValueSizeInBytes(vvmulOp.getLhs()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const { - auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vvmaxOp.getA()); - auto bAddr = memory.getValueAddress(vvmaxOp.getB()); - setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer()); + auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs()); + auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs()); + setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; json["op"] = "vvmax"; @@ -302,15 +302,15 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vvmaxOp.getA()); + json["len"] = getValueSizeInBytes(vvmaxOp.getLhs()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const { - auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vvdmulOp.getA()); - auto bAddr = memory.getValueAddress(vvdmulOp.getB()); - setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer()); + auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs()); + auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs()); + setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; json["op"] = "vvdmul"; @@ -318,132 +318,71 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const { json["rs1"] = 1; json["rs2"] = 2; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vvdmulOp.getA()); + json["len"] = getValueSizeInBytes(vvdmulOp.getLhs()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const { - auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vavgOp.getA()); - setupRdRs1(outBufAddr, 0, aAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer()); + auto inputAddr = memory.getValueAddress(vavgOp.getInput()); + setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; json["op"] = "vavg"; json["rd"] = 0; json["rs1"] = 1; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vavgOp.getA()); + json["len"] = getValueSizeInBytes(vavgOp.getInput()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const { - auto outBufAddr = memory.getValueAddress(vreluOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vreluOp.getA()); - setupRdRs1(outBufAddr, 0, aAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer()); + auto inputAddr = memory.getValueAddress(vreluOp.getInput()); + setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; json["op"] = "vrelu"; json["rd"] = 0; json["rs1"] = 1; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vreluOp.getA()); + json["len"] = getValueSizeInBytes(vreluOp.getInput()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const { - auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vtanhOp.getA()); - setupRdRs1(outBufAddr, 0, aAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer()); + auto inputAddr = memory.getValueAddress(vtanhOp.getInput()); + setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; json["op"] = "vtanh"; json["rd"] = 0; json["rs1"] = 1; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vtanhOp.getA()); + json["len"] = getValueSizeInBytes(vtanhOp.getInput()); emitInstruction(std::move(json)); } void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const { - auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf()); - auto aAddr = memory.getValueAddress(vsigmOp.getA()); - setupRdRs1(outBufAddr, 0, aAddr, 0); + auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer()); + auto inputAddr = memory.getValueAddress(vsigmOp.getInput()); + setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; json["op"] = "vsigm"; json["rd"] = 0; json["rs1"] = 1; json["offset"] = createEmptyOffset(); - json["len"] = getValueSizeInBytes(vsigmOp.getA()); + json["len"] = getValueSizeInBytes(vsigmOp.getInput()); emitInstruction(std::move(json)); } -void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const { - auto outBufAddr = memory.getValueAddress(applyFiltersOp.getOutBuf()); - auto inBufAddr = memory.getValueAddress(applyFiltersOp.getInput()); - auto accumBufAddr = memory.getValueAddress(applyFiltersOp.getAccumBuf()); - - auto weightIndices = applyFiltersOp.getWeightIndices(); - - auto inputType = cast(applyFiltersOp.getInput().getType()); - auto outputType = cast(applyFiltersOp.getOutBuf().getType()); - auto inShape = inputType.getShape(); - auto outShape = outputType.getShape(); - - size_t inChannels = inShape[1]; - size_t outChannels = outShape[1]; - size_t dimX = inShape.size() > 2 ? inShape[2] : 1; - size_t dimY = inShape.size() > 3 ? inShape[3] : 1; - - for (size_t outY = 0; outY < dimY; outY++) { - for (size_t outX = 0; outX < dimX; outX++) { - - size_t weightIndex = 0; - for (Attribute weight : weightIndices) { - // --- STEP 1: Perform MVMUL operation --- - auto weightId = cast(weight).getInt(); - size_t xKer = cast(applyFiltersOp.getXKernelPositions()[weightIndex]).getInt(); - size_t yKer = cast(applyFiltersOp.getYKernelPositions()[weightIndex]).getInt(); - weightIndex++; - - if (outX + xKer >= dimX || outY + yKer >= dimY) - continue; - - size_t outputOffset = (outY * dimX + outX) * 32 * outChannels; - size_t inputOffset = ((outY + yKer) * dimX + (outX + xKer)) * 32 * inChannels; - - bool isFirstWeight = (weightIndices[0] == weight); - - // For the first weight, store directly in output buffer; otherwise use accumulator. - size_t rdAddr = isFirstWeight ? outBufAddr : accumBufAddr; - size_t rdOffset = isFirstWeight ? outputOffset : 0; - emitMvmOp(weightId, rdAddr, rdOffset, inBufAddr, inputOffset); - - // --- STEP 2: Perform VADD operation (skip for first weight) --- - if (isFirstWeight) - continue; - - // Sum accumulator with output buffer, store result in output buffer. - setupRdRs1Rs2(outBufAddr, outputOffset, accumBufAddr, 0, outBufAddr, outputOffset); - - json::Object vaddJson; - vaddJson["op"] = "vvadd"; - vaddJson["rd"] = 0; - vaddJson["rs1"] = 1; - vaddJson["rs2"] = 2; - vaddJson["offset"] = createEmptyOffset(); - vaddJson["len"] = 32 * outChannels; - emitInstruction(std::move(vaddJson)); - } - } - } -} - void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { - auto srcAddr = memory.getValueAddress(transposeOp.getData()); - auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf()); + auto srcAddr = memory.getValueAddress(transposeOp.getInput()); + auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer()); - auto srcType = cast(transposeOp.getData().getType()); + auto srcType = cast(transposeOp.getInput().getType()); auto srcShape = srcType.getShape(); size_t rank = srcShape.size(); size_t elementSize = srcType.getElementTypeBitWidth() / 8; @@ -451,7 +390,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { // Read permutation. Destination dim i corresponds to source dim perm[i]. SmallVector perm = - map_to_vector(transposeOp.getPerms().getAsRange(), [](auto attr) -> int64_t { return attr.getInt(); }); + map_to_vector(transposeOp.getPermutation().getAsRange(), [](auto attr) -> int64_t { return attr.getInt(); }); // Destination shape: dstShape[i] = srcShape[perm[i]] SmallVector dstShape(rank); @@ -570,8 +509,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true); else if (auto mvmOp = dyn_cast(op)) coreCodeGen.codeGenMVMLikeOp(mvmOp.getWeightIndex(), mvmOp, false); - else if (auto applyFiltersOp = dyn_cast(op)) - coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); else if (auto transposeOp = dyn_cast(op)) coreCodeGen.codeGenTransposeOp(transposeOp); else if (auto vvaddOp = dyn_cast(op)) @@ -592,11 +529,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVTanhOp(vtanhOp); else if (auto vsigmOp = dyn_cast(op)) coreCodeGen.codeGenVSigmOp(vsigmOp); - else if (isa(op)) { - // TODO: Implement somehow? - op.emitWarning("Operation is not yet supported in code generation"); - continue; - } else { op.emitError("Unsupported codegen for this operation"); op.dump(); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index e08dfda..ab0f1a0 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -99,7 +99,6 @@ public: void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const; - void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; }; diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.cpp b/src/PIM/Conversion/ONNXToSpatial/Common.cpp index 03cc524..066623e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common.cpp @@ -134,366 +134,4 @@ Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter) { return (*currTensors)[0]; } -Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) { - switch (mapOp) { - case MapOperations::None: assert(false && "Invalid map operation during map operation creation."); - case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input); - case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input); - case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input); - case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input); - } -} - -void unpackOptionalPairVector(std::optional valuesArray, size_t& value1, size_t& value2) { - if (auto unpackedStrides = valuesArray) { - value1 = mlir::cast(unpackedStrides->getValue()[0]).getInt(); - value2 = mlir::cast(unpackedStrides->getValue()[1]).getInt(); - } - else { - value1 = 1; - value2 = 1; - } -} - -std::optional -unpackOptionalPadsVector(std::optional valuesArray, size_t& pad_x, size_t& pad_y) { - if (valuesArray.has_value()) { - auto pads = mlir::ArrayAttr(*valuesArray); - if (pads.size() != 4) - return "pads must have 4 elements."; - - pad_x = cast(pads[2]).getInt(); - pad_y = cast(pads[3]).getInt(); - } - else { - // Default padding is 0 unless specified otherwise. - // https://onnx.ai/onnx/operators/onnx__Conv.html - pad_x = pad_y = 0; - } - - return std::nullopt; -} - -void tileImageTensorByChannel(Value imageTensor, - SmallVector>>& tiles, - size_t tileSize, - ConversionPatternRewriter& rewriter) { - ShapedType imageShape = mlir::cast(imageTensor.getType()); - - size_t input_h = getImageHeight(imageShape); - size_t input_w = getImageWidth(imageShape); - size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize); - size_t tileRest = getImageChannel(imageShape) % tileSize; - - SmallVector strides(4, rewriter.getIndexAttr(1)); - SmallVector offsets(4, rewriter.getIndexAttr(0)); - SmallVector sizes = { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - - Location loc = imageTensor.getLoc(); - - for (size_t i = 0; i < tileCount; i++) { - if (i == tileCount - 1 && tileRest != 0) - sizes[1] = rewriter.getIndexAttr(tileRest); - for (size_t x = 0; x < input_w; x++) { - for (size_t y = 0; y < input_h; y++) { - offsets[1] = rewriter.getIndexAttr(i * tileSize); - offsets[2] = rewriter.getIndexAttr(x); - offsets[3] = rewriter.getIndexAttr(y); - - tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides); - } - } - } -} - -Value createImgConcatOp(SmallVector>>& outputTiles, - ConversionPatternRewriter& rewriter, - Location& loc, - Type outputType) { - // Populate the outputTiles for the concat in the given order: - // 1. Start top left pixel - // 2. Continue on its right pixel till the end of the row - // 3. Restart on the next row - size_t outputTileCount = outputTiles.size(); - size_t output_w = outputTiles[0].size(); - size_t output_h = outputTiles[0][0].size(); - SmallVector tilesToConcat; - tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); - for (size_t outX = 0; outX < output_h; outX++) - for (size_t outY = 0; outY < output_w; outY++) - for (size_t outTile = 0; outTile < outputTileCount; outTile++) - tilesToConcat.push_back(outputTiles[outTile][outX][outY]); - - return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat); -} - -LogicalResult -verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y) { - - if (inX < 0) { - assert((size_t) (-inX) <= pad_x && "verifyWithinBoundsAndPaddings: Negative x value out of padding"); - return failure(); - } - - if (inY < 0) { - assert((size_t) (-inY) <= pad_y && "verifyWithinBoundsAndPaddings: Negative y value out of padding"); - return failure(); - } - - if ((size_t) inX >= input_w || (size_t) inY >= input_h) { - assert((size_t) inX < input_w + pad_x && "verifyWithinBoundsAndPaddings: Positive x out of bounds"); - assert((size_t) inY < input_h + pad_y && "verifyWithinBoundsAndPaddings: Positive y out of bounds"); - return failure(); - } - - return success(); -} - -Value createExtractSliceImg(Value valToSlice, - size_t x, - size_t y, - size_t t, - size_t channelTileCount, - size_t channelTileRest, - size_t input_w, - size_t input_h, - PatternRewriter& rewriter) { - SmallVector strides(4, rewriter.getIndexAttr(1)); - SmallVector offsets(4, rewriter.getIndexAttr(0)); - SmallVector sizes = { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - - if (t == channelTileCount - 1 && channelTileRest != 0) - sizes[1] = rewriter.getIndexAttr(channelTileRest); - - offsets[1] = rewriter.getIndexAttr(t * crossbarSize); - offsets[2] = rewriter.getIndexAttr(x); - offsets[3] = rewriter.getIndexAttr(y); - - return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides); -} - -Value indexImgValue(Value v, - size_t x, - size_t y, - size_t t, - size_t channelTileCount, - size_t channelTileRest, - size_t input_w, - size_t input_h, - ConversionPatternRewriter& rewriter) { - - auto newV = rewriter.getRemappedValue(v); - if (newV) - v = newV; - - if (!v.getDefiningOp()) - return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter); - - if (auto computeOp = v.getDefiningOp()) { - // We found the computeOp that produces the tile we want, just return this - // value. - // TODO: Should we assert that x,y,t are zero? - assert(x == 0 && y == 0 && t == 0 && "indexImgValue: WeightedComputeOp tile indeces should be zero"); - return v; - } - - if (auto receiveOp = v.getDefiningOp()) { - // This is a receiveOp, just return its value which will be resolved later - assert(x == 0 && y == 0 && t == 0 && "indexImgValue: receiveOp tile indeces should be zero"); - return v; - } - - if (auto imgConcatOp = v.getDefiningOp()) { - auto imgConcatInput = imgConcatOp.getInputTile(x, y, t); - // TODO: Is this correct? - // Above we already index exactly the tile we want, so `x=y=t=0` in - // recursive call - - return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount, channelTileRest, input_w, input_h, rewriter); - } - - if (auto tensorConcatOp = v.getDefiningOp()) { - // This can be recursive. - // First, get the input tensors of the tensor.concatOp - // Then, find the input tensor that contains the tile we want - // Finally, recursive call asking for the tile - auto concatAxis = tensorConcatOp.getDim(); - assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis"); - assert(concatAxis == 1 && "TODO: Make sure this works and makes sense for other axis."); - SmallVector indexDims = {1, t * crossbarSize, x, y}; - - // Find the input tensor that contains the tile we want - size_t currentTile = 0; - for (auto concatInput : tensorConcatOp.getInputs()) { - auto concatInputShape = cast(concatInput.getType()); - assert(concatInputShape.getRank() == 4 && "Expecting an image tensor"); - auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis); - - if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) { - // This input tensor contains the tile we want - indexDims[concatAxis] -= currentTile; - if (indexDims[1] % crossbarSize != 0) { - assert(ignoreConcatError - && "TODO: Handle non-tile aligned tensor, or set " - "--ignore-concat-error=true"); - } - return indexImgValue(concatInput, - indexDims[2], - indexDims[3], - indexDims[1] / crossbarSize, - channelTileCount, - channelTileRest, - input_w, - input_h, - rewriter); - } - currentTile += concatInputSizeOnAxis; - } - - assert(false - && "Could not find the input tensor that contains the tile " - "within tensor.ConcatOp"); - } - - v.dump(); - - assert(false && "indexImgValue: unsupported operation"); -} - -void resolveInputTensorTilesBlockArg(Value wholeInputTensor, - SmallVector>>& inputTiles, - size_t channelTileCount, - size_t channelTileRest, - size_t input_w, - size_t input_h, - PatternRewriter& rewriter) { - SmallVector strides(4, rewriter.getIndexAttr(1)); - SmallVector offsets(4, rewriter.getIndexAttr(0)); - SmallVector sizes = { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Location loc = wholeInputTensor.getLoc(); - - for (size_t t = 0; t < channelTileCount; t++) { - if (t == channelTileCount - 1 && channelTileRest != 0) - sizes[1] = rewriter.getIndexAttr(channelTileRest); - for (size_t x = 0; x < input_w; x++) { - for (size_t y = 0; y < input_h; y++) { - offsets[1] = rewriter.getIndexAttr(t * crossbarSize); - offsets[2] = rewriter.getIndexAttr(x); - offsets[3] = rewriter.getIndexAttr(y); - - inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides); - } - } - } -} - -std::optional resolveImgInputTiles(Value wholeInputTensor, - SmallVector>>& inputTiles, - size_t channelTileCount, - size_t channelTileRest, - size_t input_w, - size_t input_h, - ConversionPatternRewriter& rewriter) { - - for (size_t t = 0; t < channelTileCount; t++) { - for (size_t x = 0; x < input_w; x++) { - for (size_t y = 0; y < input_h; y++) { - inputTiles[t][x][y] = - indexImgValue(wholeInputTensor, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter); - } - } - } - - return std::nullopt; -} - -LogicalResult handleFlattenLikeOp(SmallVector>& inputTiles, - const size_t inputTilesCount, - const size_t lastInputTileDimension, - TensorType inputShape, - TensorType outputShape, - Value reshapeInput, - ConversionPatternRewriter& rewriter) { - // Only support reshape between an image and a vector (i.e. flatten) - if (inputShape.getRank() != 4 || outputShape.getRank() != 2) { - return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(), - "resolveVecInputTiles only supports reshapes from 4D to 2D tensors"); - } - - /* - * From a 4D tensor to a 2D tensor - */ - auto N = inputShape.getDimSize(0); - auto C = inputShape.getDimSize(1); - auto H = inputShape.getDimSize(2); - auto W = inputShape.getDimSize(3); - assert(N == 1 && "Only support N = 1 for image tensors"); - - for (size_t i = 0; i < inputTilesCount; i++) { - auto c = (i / (H * W)) % C; - // TODO: Is this correct? Or should I invert h and w? - auto w = (i / H) % W; - auto h = i % H; - - Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount, lastInputTileDimension, W, H, rewriter); - - // Assert the shape of the tile, and reshape it - auto curTileShape = cast(curTile.getType()); - assert(curTileShape.getRank() == 4 && "We just reshaped an image tensor, why rank != 4?"); - assert(curTileShape.getDimSize(0) == 1 && "We just reshaped an image tensor with N = 1, why is it now != 1?"); - assert(curTileShape.getDimSize(2) == 1 && "We should have just looked up a single pixel why W != 1?"); - assert(curTileShape.getDimSize(3) == 1 && "We should have just looked up a single pixel why H != 1?"); - - // Reshape this pixel tensor into a vector, for compatibility with the - // rest - SmallVector newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)}; - auto shapeType = RankedTensorType::get({static_cast(newShapeVals.size())}, rewriter.getI64Type()); - Value shapeTensor = - arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals)); - auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType()); - auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor); - - size_t coreIndex = i / crossbarCountInCore; - inputTiles[coreIndex].push_back(reshapedCurTile); - } - - return success(); -} - -std::pair kernel_get_start_and_end( - int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad) { - int64_t firstValid = std::ceil(static_cast(pad) / dilation) * dilation - pad; - int64_t start = std::max(firstValid, out_pos * stride - pad); - int64_t end = std::min(input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad); - - assert(start >= 0 && "Start position must be non-negative."); - assert(end >= 0 && "End position must be non-negative."); - return std::make_pair(start, end); -} - -void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment) { - auto oldSegmentSizes = wcomputeOp->getAttrOfType(wcomputeOp.getOperandSegmentSizesAttrName()); - - auto newSegmentSizes = - DenseI32ArrayAttr::get(wcomputeOp->getContext(), {oldSegmentSizes[0], oldSegmentSizes[1] + increment}); - - wcomputeOp->setAttr(wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes); -} - -int getResultIndex(Operation* op, Value v) { - int resultNumber = -1; - for (auto result : op->getResults()) { - if (result == v) { - resultNumber = result.getResultNumber(); - break; - } - } - assert(resultNumber >= 0 && "Value not found in given operation's results."); - - return resultNumber; -} - }; // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp index 3acddb5..2e2b801 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common.hpp @@ -2,7 +2,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" @@ -10,7 +9,6 @@ #include "llvm/Support/LogicalResult.h" #include -#include #include #include #include @@ -144,164 +142,4 @@ mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); -mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input); - -/** - * Unpacks an optional pair vector into two size_t values. - * - * @param valuesArray The optional `mlir::ArrayAttr` containing the pair of - * values. - * @param value1 The reference to the first `size_t` variable to store the - * unpacked value. - * @param value2 The reference to the second `size_t` variable to store the - * unpacked value. - */ -void unpackOptionalPairVector(std::optional valuesArray, size_t& value1, size_t& value2); - -/** - * Unpacks the optional pads vector. - * - * @param valuesArray The optional array attribute containing the values. - * @param pad_x The output variable to store the value of pad_x. - * @param pad_y The output variable to store the value of pad_y. - * @param rewriter The rewriter to notify failure - * - * @return llvm::Optional The error message if the pads are invalid - */ -std::optional -unpackOptionalPadsVector(std::optional valuesArray, size_t& pad_x, size_t& pad_y); - -/** - * Tiles the image tensor by channel. - * - * This function takes an image tensor and tiles it into smaller tiles based on - * the channel dimension. The size of each tile is specified by the tileSize - * parameter. - * - * @param imageTensor The input image tensor (NxCxWxH) to be tiled. - * @param tiles The output tiles vector to store the tiled image tensors. - * @param tileSize The size of each tile. - * @param rewriter The ConversionPatternRewriter used for creating operations. - */ -void tileImageTensorByChannel(mlir::Value imageTensor, - llvm::SmallVector>>& tiles, - size_t tileSize, - mlir::ConversionPatternRewriter& rewriter); - -/** - * Creates an ImgConcatOp based on the given tiles. - * - * This function takes a 3-dimensional vector `outputTiles` representing the - * tiles to concatenate. The tiles are indexed by [tile][x][y]. - * - * @param outputTiles The tiles to concatenate. - * @param rewriter The ConversionPatternRewriter used for creating the - * ImgConcatOp. - * @param loc The location of the operation. - * @param outputType The type of the output tensor. - * - * @return The created ImgConcatOp. - */ -mlir::Value createImgConcatOp(llvm::SmallVector>>& outputTiles, - mlir::ConversionPatternRewriter& rewriter, - mlir::Location& loc, - mlir::Type outputType); - -/** - * @brief Verifies if the given input coordinates and padding values are within - * the bounds of the input tensor. - * - * @param input_w The width of the input tensor. - * @param input_h The height of the input tensor. - * @param inX The X-coordinate of the input. - * @param inY The Y-coordinate of the input. - * @param pad_x The padding value in the X-direction. - * @param pad_y The padding value in the Y-direction. - * @return LogicalResult Returns success if the coordinates and padding are - * within bounds, failure otherwise. - */ -mlir::LogicalResult -verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y); - -/** - * Resolves the tiling of the input tensor into smaller tiles. - * - * This function takes a whole input tensor and tiles it into smaller tiles - * using the provided parameters. The resulting tiles are stored in the - * `inputTiles` vector. - * Input tiles need to be indexed by: - * a. Channel Tile - * b. Pixel `x` position - * c. Pixel `y` position - * For example: inputTiles[channelTile][x][y] - * - * @param wholeInputTensor The whole input tensor to be tiled. - * @param inputTiles A vector of vectors of vectors of Values representing the - * tiles of the input tensor. The outermost vector represents - * the channels, the middle vector represents the rows, and - * the innermost vector represents the columns of the tiles. - * @param channelTileCount The number of tiles for the `channel` axis. - * @param channelTileRest The size of the last channelTile. Set as 0 if tiles - * fit exactly - * @param input_w The width of the input tensor. - * @param input_h The height of the input tensor. - * @param rewriter The ConversionPatternRewriter used for creating operations. - * - * @return std::optional An error message if the input tensor could - * not be resolved into tiles. - */ -std::optional -resolveImgInputTiles(mlir::Value wholeInputTensor, - llvm::SmallVector>>& inputTiles, - size_t channelTileCount, - size_t channelTileRest, - size_t input_w, - size_t input_h, - mlir::ConversionPatternRewriter& rewriter); - -/** - * Computes the boundaries of an image kernel application. - * - * @param out_pos The position of the output element. - * @param input_width The width of the input image. - * @param krn_width The width of the kernel. - * @param stride The stride value. - * @param dilation The dilation value. - * @param pad The padding value. - * @return A pair of size_t values representing the start and end positions of - * the kernel application. - */ -std::pair kernel_get_start_and_end( - int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad); - -/** - * @brief Increment the `operandSegmentSizes` in the WeightedCompute operation - * for the `inputs` operand. - * - * This function increments the size of the `inputs` operand segment in the - * `operandSegmentSizes` of the given WeightedCompute operation by the specified - * increment. This is necessary when new operands are programmatically added to - * the WeightedCompute operation. - * - * @param wcomputeOp The WeightedCompute operation whose `operandSegmentSizes` - * is to be incremented. - * @param increment The value by which to increment the `inputs` operand segment - * size. - */ -void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment); - -/** - * @brief Finds the result index of the given operation that produces the - * specified value. - * - * This function takes an operation and a value, and returns the index of the - * result of the operation that corresponds to the given value. - * - * @param op Operation whose result index is to be found. - * @param v The value for which the result index is to be determined. - * @return The index of the result of the operation that produces the specified - * value. - */ -int getResultIndex(mlir::Operation* op, mlir::Value v); - }; // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp index db4a0b6..bbe8aa0 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -223,9 +223,6 @@ void SpatialToGraphvizPass::runOnOperation() { else if (auto concatOp = dyn_cast(op)) { drawConcatOpSubgraph(concatOp, concatNum++); } - else if (auto imgConcatOp = dyn_cast(op)) { - drawConcatOpSubgraph(imgConcatOp, concatNum++); - } else if (auto extractSliceOp = dyn_cast(op)) { auto producerOp = extractSliceOp->getOperand(0).getDefiningOp(); if (producerOp) { diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index 558f919..be8fe5a 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -45,8 +45,4 @@ createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType()); } -inline bool isAConcatOp(mlir::Operation* op) { - return llvm::isa(op) || llvm::isa(op); -} - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index c97932a..88eadbc 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -45,4 +45,5 @@ def spatToPimVVMaxOp : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; + #endif // SPATIAL_TO_PIM diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 2e93989..fedb72d 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -278,7 +278,7 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR continue; } - if (isa(resultUser) || isa(resultUser)) { + if (isa(resultUser)) { auto concatOp = resultUser; auto concatValue = concatOp->getResult(0); auto concatUses = concatValue.getUses(); @@ -368,8 +368,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I }; funcOp.walk([&](PimVMMOp vmmOp) { - auto outTensorOperand = vmmOp.getOutBuf(); - auto resultTensor = vmmOp.getOutRes(); + auto outTensorOperand = vmmOp.getOutputBuffer(); + auto resultTensor = vmmOp.getOutput(); auto outShape = getTensorShape(outTensorOperand); assert(isHVectorShape(outShape)); if (outShape[1] != static_cast(crossbarSize)) { @@ -602,9 +602,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); - // If the operand is a concatenation operation and the returnOp was the only - // user of the returnOperand, we can safely remove it - if (isAConcatOp(returnOperand)) { + if (isa(returnOperand)) { auto returnOperandUses = it.value().getUses(); if (rangeLength(returnOperandUses) == 0) rewriter.eraseOp(returnOperand); @@ -632,7 +630,7 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I // user. This means that we need to get the replace the original SendOp with // a BroadcastSendOp rewriter.setInsertionPoint(sendOp); - rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(), sendOp.getData()); + rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(), sendOp.getInput()); } } diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 99f2a58..616da86 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -20,51 +20,18 @@ class PimOp traits = []> : def PimTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; -// Communication +//===----------------------------------------------------------------------===// +// Execution +//===----------------------------------------------------------------------===// -def PimSendOp: PimOp<"send", []> { - let arguments = (ins - PimTensor: $src, - I32Attr: $size, - I32Attr: $targetCoreId - ); - - let assemblyFormat = [{ - `(` $src `)` attr-dict `:` type($src) `->` `(` `)` - }]; -} - -def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> { - let arguments = (ins - PimTensor: $dst, - I32Attr: $size, - I32Attr: $srcCoreId - ); - - let results = (outs - PimTensor: $out - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getDstMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $dst `)` attr-dict `:` type($dst) `->` type($out) - }]; -} - -// Core - -def PimCoreOp: PimOp<"core", [SingleBlock]> { +def PimCoreOp : PimOp<"core", [SingleBlock]> { + let summary = "Execute a block on a PIM core"; let regions = (region SizedRegion<1>:$body); let arguments = (ins Variadic:$weights, - I32Attr: $coreId + I32Attr:$coreId ); let assemblyFormat = [{ @@ -72,412 +39,443 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> { }]; } -// Memory - -def PimConstantOp: PimOp<"constant", []> { - let description = [{ - Allocate a constant value in global memory - }]; - - let arguments = (ins - AnyAttr: $value, - BoolAttr: $shouldAllocate - ); - - let results = (outs - PimTensor: $out - ); -} - -def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> { - let description = [{ - Copy a memory region from host memory into device memory - }]; - - let arguments = (ins - PimTensor: $deviceDst, - PimTensor: $hostSrc, - I32Attr: $deviceDstOffset, - I32Attr: $hostSrcOffset, - I32Attr: $size - ); - - let results = (outs - PimTensor: $deviceDstOut - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getDeviceDstMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut) - }]; -} - -def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> { - let description = [{ - Copy a memory region from device memory into host memory - }]; - - let arguments = (ins - PimTensor: $hostDst, - PimTensor: $deviceSrc, - I32Attr: $hostDstOffset, - I32Attr: $deviceSrcOffset, - I32Attr: $size - ); - - let results = (outs - PimTensor: $hostDstOut - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getHostDstMutable(); - } - }]; - - - let assemblyFormat = [{ - `(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut) - }]; -} - -def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> { - let description = [{ - Copy a memory region from and to the same memory - }]; - - let arguments = (ins - PimTensor: $dst, - PimTensor: $src, - I32Attr: $dstOffset, - I32Attr: $srcOffset, - I32Attr: $size - ); - - let results = (outs - PimTensor: $dstOut - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getDstMutable(); - } - }]; - - - let assemblyFormat = [{ - `(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut) - }]; -} - -// Algebra - -def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> { - let description = [{ - Matrix transpose - }]; - - let arguments = (ins - PimTensor: $data, - I64ArrayAttr: $perms, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes) - }]; -} - -def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { - let description = [{ - Vector-matrix multiplication: c = a * b - }]; - - let arguments = (ins - I32Attr: $weightIndex, - PimTensor: $vectorInput, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes) - }]; -} - -def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> { - let description = [{ - Matrix-vector multiplication: c = a * b - }]; - - let arguments = (ins - I32Attr: $weightIndex, - PimTensor: $vectorInput, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; -} - -def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> { - let description = [{ - Element-wise addition: c = a + b - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $b, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) - }]; -} - -def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> { - let description = [{ - Element-wise subtraction: c = a - b - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $b, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) - }]; -} - -def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> { - let description = [{ - Element-wise multiplication: c = a * b - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $b, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) - }]; -} - -def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> { - let description = [{ - Element-wise max: c = max(a, b) - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $b, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutBufMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes) - }]; -} - -def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods]> { - let description = [{ - Dot product: c = dot(a, b) - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $b, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); -} - -def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods]> { - let description = [{ - Apply filters to a tensor - }]; - - let arguments = (ins - I64ArrayAttr: $weightIndices, - I64ArrayAttr: $xKernelPositions, - I64ArrayAttr: $yKernelPositions, - PimTensor: $input, - PimTensor: $outBuf, - PimTensor: $accumBuf - ); - - let results = (outs - PimTensor: $outRes - ); - - let assemblyFormat = [{ - `(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:` - type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes) - }]; -} - -def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods]> { - let description = [{ - Sum all elements into a single one - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); -} - -def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods]> { - let description = [{ - Average all elements into a single one - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); -} - -def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods]> { - let description = [{ - Element-wise ReLU: c = max(a, 0) - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); -} - -def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods]> { - let description = [{ - Element-wise tanh activation - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); -} - -def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods]> { - let description = [{ - Element-wise sigmoid activation - }]; - - let arguments = (ins - PimTensor: $a, - PimTensor: $outBuf - ); - - let results = (outs - PimTensor: $outRes - ); -} - -def PimHaltOp: PimOp<"halt", [Terminator]> { - let description = [{ - Halts the execution of the core - }]; +def PimHaltOp : PimOp<"halt", [Terminator]> { + let summary = "Halt execution of the core"; let assemblyFormat = [{ attr-dict }]; } +//===----------------------------------------------------------------------===// +// Communication +//===----------------------------------------------------------------------===// + +def PimSendOp : PimOp<"send", []> { + let summary = "Send a tensor to another core"; + + let arguments = (ins + PimTensor:$input, + I32Attr:$size, + I32Attr:$targetCoreId + ); + + let assemblyFormat = [{ + `(` $input `)` attr-dict `:` type($input) `->` `(` `)` + }]; +} + +def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { + let summary = "Receive a tensor from another core"; + + let arguments = (ins + PimTensor:$outputBuffer, + I32Attr:$size, + I32Attr:$sourceCoreId + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output) + }]; +} + +def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { + let summary = "Copy a memory region from host memory into device memory"; + + let arguments = (ins + PimTensor:$deviceTarget, + PimTensor:$hostSource, + I32Attr:$deviceTargetOffset, + I32Attr:$hostSourceOffset, + I32Attr:$size + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getDeviceTargetMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output) + }]; +} + +def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { + let summary = "Copy a memory region from device memory into host memory"; + + let arguments = (ins + PimTensor:$hostTarget, + PimTensor:$deviceSource, + I32Attr:$hostTargetOffset, + I32Attr:$deviceSourceOffset, + I32Attr:$size + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getHostTargetMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output) + }]; +} + +def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> { + let summary = "Copy a memory region within the same memory space"; + + let arguments = (ins + PimTensor:$target, + PimTensor:$source, + I32Attr:$targetOffset, + I32Attr:$sourceOffset, + I32Attr:$size + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getTargetMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output) + }]; +} + +//===----------------------------------------------------------------------===// +// Math +//===----------------------------------------------------------------------===// + +def PimTransposeOp : PimOp<"transpose", [DestinationStyleOpInterface]> { + let summary = "Transpose a matrix"; + + let arguments = (ins + PimTensor:$input, + I64ArrayAttr:$permutation, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> { + let summary = "Vector-matrix multiplication: c = a * b"; + + let arguments = (ins + I32Attr:$weightIndex, + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> { + let summary = "Matrix-vector multiplication: c = a * b"; + + let arguments = (ins + I32Attr:$weightIndex, + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> { + let summary = "Element-wise addition: c = a + b"; + + let arguments = (ins + PimTensor:$lhs, + PimTensor:$rhs, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVVSubOp : PimOp<"vvsub", [DestinationStyleOpInterface]> { + let summary = "Element-wise subtraction: c = a - b"; + + let arguments = (ins + PimTensor:$lhs, + PimTensor:$rhs, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVVMulOp : PimOp<"vvmul", [DestinationStyleOpInterface]> { + let summary = "Element-wise multiplication: c = a * b"; + + let arguments = (ins + PimTensor:$lhs, + PimTensor:$rhs, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVVMaxOp : PimOp<"vvmax", [DestinationStyleOpInterface]> { + let summary = "Element-wise max: c = max(a, b)"; + + let arguments = (ins + PimTensor:$lhs, + PimTensor:$rhs, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> { + let summary = "Dot product: c = dot(a, b)"; + + let arguments = (ins + PimTensor:$lhs, + PimTensor:$rhs, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> { + let summary = "Reduce all elements to a single value"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> { + let summary = "Average all elements into a single value"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVReluOp : PimOp<"vrelu", [DestinationStyleOpInterface]> { + let summary = "Element-wise ReLU: c = max(a, 0)"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVTanhOp : PimOp<"vtanh", [DestinationStyleOpInterface]> { + let summary = "Element-wise tanh activation"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + +def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> { + let summary = "Element-wise sigmoid activation"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + #endif // PIM_DIALECT_H diff --git a/src/PIM/Dialect/Pim/PimOps.cpp b/src/PIM/Dialect/Pim/PimOps.cpp index 8e9d2f6..1c59c9a 100644 --- a/src/PIM/Dialect/Pim/PimOps.cpp +++ b/src/PIM/Dialect/Pim/PimOps.cpp @@ -25,19 +25,6 @@ void PimDialect::initialize() { >(); } -#define POPULATE_DEPENDENCIES(OP_NAME) \ - void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \ - registerDependenciesFn(this->getOutBuf(), this->getResult()); \ - } - -POPULATE_DEPENDENCIES(PimVVDMulOp) -POPULATE_DEPENDENCIES(PimApplyFiltersOp) -POPULATE_DEPENDENCIES(PimSumOp) -POPULATE_DEPENDENCIES(PimVAvgOp) -POPULATE_DEPENDENCIES(PimVReluOp) -POPULATE_DEPENDENCIES(PimVTanhOp) -POPULATE_DEPENDENCIES(PimVSigmOp) - } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 55d886f..f150af0 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -30,7 +30,7 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(sizeInBytes)) - .getDstOut(); + .getOutput(); } struct MemCopyHostToDevOpInterface @@ -40,26 +40,26 @@ struct MemCopyHostToDevOpInterface const BufferizationOptions& options, BufferizationState& state) const { auto memCopyHostToDevOp = cast(op); - auto deviceDst = memCopyHostToDevOp.getDeviceDst(); - auto hostSrc = memCopyHostToDevOp.getHostSrc(); + auto deviceTarget = memCopyHostToDevOp.getDeviceTarget(); + auto hostSource = memCopyHostToDevOp.getHostSource(); - auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state); - if (failed(deviceDstOpt)) + auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state); + if (failed(deviceTargetOpt)) return failure(); - auto deviceDstMemRef = *deviceDstOpt; + auto deviceTargetMemRef = *deviceTargetOpt; - auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state); - if (failed(hostSrcOpt)) + auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state); + if (failed(hostSourceOpt)) return failure(); - auto hostSrcMemRef = *hostSrcOpt; + auto hostSourceMemRef = *hostSourceOpt; replaceOpWithNewBufferizedOp(rewriter, memCopyHostToDevOp, - deviceDstMemRef.getType(), - deviceDstMemRef, - hostSrcMemRef, - memCopyHostToDevOp.getDeviceDstOffsetAttr(), - memCopyHostToDevOp.getHostSrcOffsetAttr(), + deviceTargetMemRef.getType(), + deviceTargetMemRef, + hostSourceMemRef, + memCopyHostToDevOp.getDeviceTargetOffsetAttr(), + memCopyHostToDevOp.getHostSourceOffsetAttr(), memCopyHostToDevOp.getSizeAttr()); return success(); } @@ -73,25 +73,25 @@ struct MemCopyDevToHostOpInterface BufferizationState& state) const { auto memCopyDevToHostOp = cast(op); - auto globalDst = memCopyDevToHostOp.getHostDst(); - auto globalDstOpt = getBuffer(rewriter, globalDst, options, state); - if (failed(globalDstOpt)) + auto hostTarget = memCopyDevToHostOp.getHostTarget(); + auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state); + if (failed(hostTargetOpt)) return failure(); - auto globalDstMemRef = *globalDstOpt; + auto hostTargetMemRef = *hostTargetOpt; - auto localSrc = memCopyDevToHostOp.getDeviceSrc(); - auto localSrcOpt = getBuffer(rewriter, localSrc, options, state); - if (failed(localSrcOpt)) + auto deviceSource = memCopyDevToHostOp.getDeviceSource(); + auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state); + if (failed(deviceSourceOpt)) return failure(); - auto localSrcMemRef = *localSrcOpt; + auto deviceSourceMemRef = *deviceSourceOpt; replaceOpWithNewBufferizedOp(rewriter, memCopyDevToHostOp, - globalDstMemRef.getType(), - globalDstMemRef, - localSrcMemRef, - memCopyDevToHostOp.getHostDstOffsetAttr(), - memCopyDevToHostOp.getDeviceSrcOffsetAttr(), + hostTargetMemRef.getType(), + hostTargetMemRef, + deviceSourceMemRef, + memCopyDevToHostOp.getHostTargetOffsetAttr(), + memCopyDevToHostOp.getDeviceSourceOffsetAttr(), memCopyDevToHostOp.getSizeAttr()); return success(); } @@ -109,16 +109,16 @@ struct TransposeOpBufferizeInterface BufferizationState& state) const { auto transposeOp = cast(op); - auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state); - if (failed(dataOpt)) + auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state); + if (failed(inputOpt)) return failure(); - auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state); - if (failed(outBufOpt)) + auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) return failure(); replaceOpWithNewBufferizedOp( - rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt); + rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt); return success(); } }; @@ -132,9 +132,9 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); Value readVal = uRead->get(); Value writeVal = uWrite->get(); - if (writeVal != vmmOp.getOutBuf()) + if (writeVal != vmmOp.getOutputBuffer()) return false; - if (readVal == vmmOp.getVectorInput()) + if (readVal == vmmOp.getInput()) if (state.areEquivalentBufferizedValues(readVal, writeVal)) return true; return false; @@ -146,16 +146,16 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); - auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state); - if (failed(vectorInputOpt)) + auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state); + if (failed(inputOpt)) return failure(); - auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state); - if (failed(outBufOpt)) + auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) return failure(); replaceOpWithNewBufferizedOp( - rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt); + rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt); return success(); } }; @@ -171,16 +171,16 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); - auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state); - if (failed(vectorInputOpt)) + auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state); + if (failed(inputOpt)) return failure(); - auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state); - if (failed(outBufOpt)) + auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) return failure(); replaceOpWithNewBufferizedOp( - rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt); + rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt); return success(); } }; @@ -203,22 +203,23 @@ struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); - auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state); - if (failed(aOpt)) + auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state); + if (failed(lhsOpt)) return failure(); - auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state); - if (failed(bOpt)) + auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state); + if (failed(rhsOpt)) return failure(); - auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state); - if (failed(outBufOpt)) + auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) return failure(); - Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter); - Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter); + Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter); + Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); - replaceOpWithNewBufferizedOp(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt); + replaceOpWithNewBufferizedOp( + rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt); return success(); } }; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td index bc920e3..f5040a4 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td @@ -16,4 +16,5 @@ def memrefCopyToPimMemCopyOp : Pat< (returnType $dst)) >; + #endif // PIM_BUFFERIZATION diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 5397ec4..84de4a2 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -16,7 +16,7 @@ class SpatOp traits = []> : Op; // TODO maybe remove and use AnyRankedTensor directly -def SpatTensor: +def SpatTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; class SpatType traits = []> @@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> { let summary = "Virtual channel type"; } -def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { - let summary = "Compute operation, with constant weights already attached"; +//===----------------------------------------------------------------------===// +// Execution +//===----------------------------------------------------------------------===// + +def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { + let summary = "Compute region with attached constant weights"; let arguments = (ins Variadic:$weights, @@ -49,7 +53,9 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment }]; } -def SpatYieldOp: SpatOp<"yield", [Terminator]> { +def SpatYieldOp : SpatOp<"yield", [Terminator]> { + let summary = "Yield results from a compute region"; + let arguments = (ins Variadic:$outputs ); @@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> { } //===----------------------------------------------------------------------===// -// Data movement operations +// Communication //===----------------------------------------------------------------------===// -def SpatChannelNewOp: SpatOp<"channel_new", []> { +def SpatChannelNewOp : SpatOp<"channel_new", []> { + let summary = "Create a new virtual channel"; + let results = (outs - SpatChannelType:$new_channel + SpatChannelType:$channel ); let builders = [ @@ -79,108 +87,74 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> { }]; } -def SpatChannelSendOp: SpatOp<"channel_send", []> { +def SpatChannelSendOp : SpatOp<"channel_send", []> { + let summary = "Send a tensor through a channel"; + let arguments = (ins - SpatChannelType: $channel, - SpatTensor: $data + SpatChannelType:$channel, + SpatTensor:$input ); let assemblyFormat = [{ - $data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)` + $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)` }]; } -def SpatChannelReceiveOp: SpatOp<"channel_receive", []> { +def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { + let summary = "Receive a tensor from a channel"; + let arguments = (ins - SpatChannelType: $channel + SpatChannelType:$channel ); let results = (outs - SpatTensor: $data + SpatTensor:$output ); let assemblyFormat = [{ - $channel attr-dict `:` `(` type($channel) `->` type($data) `)` + $channel attr-dict `:` `(` type($channel) `->` type($output) `)` }]; } def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> { + let summary = "Broadcast a tensor through a shared channel buffer"; + let arguments = (ins - SpatChannelType: $channel, - SpatTensor: $data + SpatChannelType:$channel, + SpatTensor:$input ); + + let assemblyFormat = [{ + $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)` + }]; } def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> { + let summary = "Receive a tensor from a shared channel buffer"; + let arguments = (ins - SpatChannelType: $channel + SpatChannelType:$channel ); let results = (outs - SpatTensor: $data + SpatTensor:$output ); -} -//===----------------------------------------------------------------------===// -// Math operations -//===----------------------------------------------------------------------===// - -def SpatConstantOp: SpatOp<"constant", []> { - let description = [{ - "Constant value, should be used for weights and biases" + let assemblyFormat = [{ + $channel attr-dict `:` `(` type($channel) `->` type($output) `)` }]; - - let arguments = (ins - AnyAttr: $value, - BoolAttr: $shouldAllocate - ); - - let results = (outs - SpatTensor: $out - ); } -def SpatWeightedVMMOp: SpatOp<"Wvmm", []> { - let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute."; +//===----------------------------------------------------------------------===// +// Math +//===----------------------------------------------------------------------===// + +def SpatWeightedVMMOp : SpatOp<"Wvmm", []> { + let summary = "Vector-matrix multiplication within a weighted compute operation"; let arguments = (ins - I32Attr: $weightIndex, - SpatTensor:$vector - ); - - let results = (outs - SpatTensor:$output - ); - - // TODO: Verifier that checks it is within a WeightedCompute operation, - // that the weightIndex is valid, and that the matrix is of the right size. - let hasVerifier = 1; -} - -def SpatWeightedMVMOp: SpatOp<"Wmvm", []> { - let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute."; - - let arguments = (ins - I32Attr: $weightIndex, - SpatTensor:$vector - ); - - let results = (outs - SpatTensor:$output - ); - - // TODO: Verifier that checks it is within a WeightedCompute operation, - // that the weightIndex is valid, and that the matrix is of the right size. - let hasVerifier = 1; -} - - -def SpatVAddOp: SpatOp<"vadd", []> { - let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1"; - - let arguments = (ins - SpatTensor: $a, - SpatTensor: $b + I32Attr:$weightIndex, + SpatTensor:$input ); let results = (outs @@ -190,76 +164,68 @@ def SpatVAddOp: SpatOp<"vadd", []> { let hasVerifier = 1; let assemblyFormat = [{ - $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output) + `(` $input `)` attr-dict `:` type($input) `->` type($output) }]; } -def SpatVMulOp: SpatOp<"vmul", []> { - let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1"; - - let arguments = (ins - SpatTensor: $a, - SpatTensor: $b - ); - - let results = (outs - SpatTensor:$output - ); - - //let hasVerifier = 1; - - let assemblyFormat = [{ - $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output) - }]; -} - -def SpatVDivOp: SpatOp<"vdiv", []> { - let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1"; +def SpatWeightedMVMOp : SpatOp<"Wmvm", []> { + let summary = "Matrix-vector multiplication within a weighted compute operation"; let arguments = (ins - SpatTensor:$a, - SpatTensor:$b + I32Attr:$weightIndex, + SpatTensor:$input ); let results = (outs SpatTensor:$output ); - //let hasVerifier = 1; + let hasVerifier = 1; let assemblyFormat = [{ - $a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output) + `(` $input `)` attr-dict `:` type($input) `->` type($output) }]; } -//TODO: remove -def SpatVSDivOp: SpatOp<"vsdiv", []> { - - let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)"; +def SpatVAddOp : SpatOp<"vadd", []> { + let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1"; let arguments = (ins - SpatTensor:$dividend, - SpatTensor:$divisor + SpatTensor:$lhs, + SpatTensor:$rhs ); let results = (outs SpatTensor:$output ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) + }]; } -def SpatSumOp: SpatOp<"sum", []> { - let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience"; +def SpatVMulOp : SpatOp<"vmul", []> { + let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1"; let arguments = (ins - SpatTensor: $input + SpatTensor:$lhs, + SpatTensor:$rhs ); let results = (outs SpatTensor:$output ); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) + }]; } -def SpatSigmoidOp: SpatOp<"sigmoid", []> { +def SpatSumOp : SpatOp<"sum", []> { + let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor"; + let arguments = (ins SpatTensor:$input ); @@ -267,9 +233,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> { let results = (outs SpatTensor:$output ); + + let assemblyFormat = [{ + `(` $input `)` attr-dict `:` type($input) `->` type($output) + }]; } -def SpatReluOp: SpatOp<"relu", []> { +def SpatSigmoidOp : SpatOp<"sigmoid", []> { + let summary = "Element-wise sigmoid activation"; + let arguments = (ins SpatTensor:$input ); @@ -277,68 +249,34 @@ def SpatReluOp: SpatOp<"relu", []> { let results = (outs SpatTensor:$output ); -} - -def SpatVMaxOp: SpatOp<"vmax", []> { - - let summary = "Element-wise max function"; - - let arguments = (ins - SpatTensor: $a, - SpatTensor: $b - ); - - let results = (outs - SpatTensor:$output - ); - - let hasVerifier = 1; -} - -def SpatApplyFiltersOp : SpatOp<"apply_filters", []> { - let summary = "Apply multiple crossbar weights to a convolutional input tile."; - let description = [{ - Applies a variable number of crossbar weights to a single large image tensor tile, - producing a corresponding output tile. This essentially encapsulates a big for loop - over all pixels in the input tile, where each pixel is multiplied by all the weights - in the operation. - }]; - - let arguments = (ins - I64ArrayAttr: $weightIndices, - I64ArrayAttr: $xKernelPositions, - I64ArrayAttr: $yKernelPositions, - SpatTensor: $input - ); - let results = (outs SpatTensor); let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type(results) + `(` $input `)` attr-dict `:` type($input) `->` type($output) }]; } -//===----------------------------------------------------------------------===// -// Other operations -//===----------------------------------------------------------------------===// - -def SpatImgConcatOp: SpatOp<"img_concat", []> { - - let summary = "Concatenate pixel tiles into a single image"; - - let description = [{ - Concatenate pixel tiles into a single image: - 1. First, concatenate the pixel tiles along the "channel" axis (axis 1). - 2. Next, concatenate the pixel tiles along the "width" axis (axis 2). - 3. Finally, concatenate the pixel tiles along the "height" axis (axis 3). - - The input tiles should be provided in a specific order: - start from the top left pixel, - then continue with the pixel on its right, - and once you finish the first row of pixels, go to the next row. - }]; +def SpatReluOp : SpatOp<"relu", []> { + let summary = "Element-wise ReLU activation"; let arguments = (ins - Variadic:$inputs + SpatTensor:$input + ); + + let results = (outs + SpatTensor:$output + ); + + let assemblyFormat = [{ + `(` $input `)` attr-dict `:` type($input) `->` type($output) + }]; +} + +def SpatVMaxOp : SpatOp<"vmax", []> { + let summary = "Element-wise max between two tensors"; + + let arguments = (ins + SpatTensor:$lhs, + SpatTensor:$rhs ); let results = (outs @@ -347,9 +285,9 @@ def SpatImgConcatOp: SpatOp<"img_concat", []> { let hasVerifier = 1; - let extraClassDeclaration = [{ - mlir::Value getInputTile(size_t x, size_t y, size_t tile); + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) }]; } -#endif // SPATIAL_DIALECT_H \ No newline at end of file +#endif // SPATIAL_DIALECT_H diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index de1c1be..7023b35 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() { if (failed(matrixShapeOpt)) return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op"); auto matrixShape = *matrixShapeOpt; - auto vectorShape = getVector().getType().getShape(); + auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); /* Two possible accepted shapes: @@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() { if (failed(matrixShapeOpt)) return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op"); auto matrixShape = *matrixShapeOpt; - auto vectorShape = getVector().getType().getShape(); + auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); /* Accepted shape: @@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() { return OpTrait::impl::verifySameOperandsAndResultType(*this); } -LogicalResult SpatImgConcatOp::verify() { - auto imgShape = mlir::cast(getType()); - size_t img_w = getImageWidth(imgShape); - size_t img_h = getImageHeight(imgShape); - size_t img_c = getImageChannel(imgShape); - - size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); - size_t channelTileRest = img_c % crossbarSize; - - auto operands = getOperands(); - - // Check number of operands - if (img_w * img_h * channelTiles != operands.size()) - return emitError("Number of operands does not match output image size"); - - // For each output pixel, check that the inputTiles have a correct shape - for (size_t x = 0; x < img_w; x++) { - for (size_t y = 0; y < img_h; y++) { - size_t channel_counts = 0; - for (size_t t = 0; t < channelTiles; t++) { - auto inputShape = mlir::cast(getInputTile(x, y, t).getType()); - if (!inputShape) - return emitError("Invalid input type, must be ShapedType"); - - // N == W == H == 1 - if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1) - return emitError("Invalid input shape: N,W,H must all be 1"); - - size_t inputChannels = getImageChannel(inputShape); - - // Check the number of channels in this tile are correct: - // - CASE1: last tile of pixel, if there is some rest it must match that - // - CASE2: common case, the channel count is exactly the crossbarSize - if (t == channelTiles - 1 && channelTileRest != 0) { - if (inputChannels != channelTileRest) - return emitError("Invalid channel count for last tile of pixel"); - } - else { - if (inputChannels != crossbarSize) - return emitError("Invalid channel count for some pixel tile"); - } - - channel_counts += inputChannels; - } - - if (channel_counts != img_c) - emitError("Invalid number of channels for some pixel"); - } - } - - return success(); -} - LogicalResult SpatWeightedCompute::verify() { // Check that it has a terminator, it is a yieldOp, and it has a single // operand with the same type as the result @@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() { return success(); } -Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) { - auto operands = getOperands(); - auto imgShape = mlir::cast(getType()); - size_t img_w = getImageWidth(imgShape); - size_t img_h = getImageHeight(imgShape); - size_t img_c = getImageChannel(imgShape); - - size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); - - assert(tile < channelTiles); - assert(x < img_w); - assert(y < img_h); - - return operands[tile + x * channelTiles + y * img_w * channelTiles]; -} - } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index 304ef64..a27cc89 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -53,7 +53,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(sizeInBytes)) - .getDstOut(); + .getOutput(); } const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); @@ -202,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa memrefOperands.push_back(outputTensor); - Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes(); + Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput(); replaceOpWithBufferizedValues(rewriter, op, newValue); @@ -243,7 +243,7 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod cast(op).getWeightIndexAttr(), memrefOperand, outputTensor) - .getOutRes(); + .getOutput(); replaceOpWithBufferizedValues(rewriter, op, newValue); @@ -289,7 +289,7 @@ struct ChannelReceiveOpInterface outputTensor, rewriter.getI32IntegerAttr(numElements * elementSize), rewriter.getI32IntegerAttr(srcCoreId.value())) - .getOut(); + .getOutput(); replaceOpWithBufferizedValues(rewriter, op, newValue); @@ -403,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputSize)); - replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst()); + replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput()); return success(); } @@ -490,84 +490,6 @@ struct SumOpInterface : VariadicArgumentElementWiseOpInterface {}; -// Create a new bufferizable op interface for the apply filters operation. -struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel { - - // One operand ($input) is read from. All other inputs are only written to. - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - - // Operand 0: $input - // Operand 1: $outBuf - // Operand 2: $accumBuf - return opOperand.getOperandNumber() == 0; - } - - // One input ($accumBuf) is written to. All other inputs are only read. - bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - - // Operand 0: $input - // Operand 1: $outBuf - // Operand 2: $accumBuf - return opOperand.getOperandNumber() == 2; - } - - // No operands are aliased with any other operands. - AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return {}; - } - - // Bufferize the operation. - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - - // Get the input tensor buffer. - auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state); - - if (failed(inputBuffer)) - return failure(); - - // Create a new buffer for the output tensor. - auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - - // Create a new buffer for the accumulation buffer. - // To do this, create a new allocation operation. Size must be axbx1x1, - // where axbxcxd is the size of the output tensor. Since the shape is - // different, we can't immediately use createEmptyFromType, we first need to - // create the shape of the accumulation buffer. - auto accumShape = llvm::to_vector<4>(cast(op->getResult(0).getType()).getShape()); - - // Set the last two dimensions to 1. - accumShape[accumShape.size() - 1] = 1; - accumShape[accumShape.size() - 2] = 1; - - auto accumType = MemRefType::get(accumShape, cast(op->getResult(0).getType()).getElementType()); - - auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter); - - // Bufferize the operation. - auto weightIndices = cast(op).getWeightIndicesAttr(); - auto xKernelPositions = cast(op).getXKernelPositionsAttr(); - auto yKernelPositions = cast(op).getYKernelPositionsAttr(); - - Value bufferized = pim::PimApplyFiltersOp::create(rewriter, - op->getLoc(), - outputTensor.getType(), - weightIndices, - xKernelPositions, - yKernelPositions, - *inputBuffer, - outputTensor, - accumBuffer); - - // Replace the operation with the bufferized value. - replaceOpWithBufferizedValues(rewriter, op, bufferized); - - return success(); - } -}; - void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) { SpatWeightedCompute::attachInterface(*ctx); @@ -580,7 +502,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { SpatChannelSendOp::attachInterface(*ctx); SpatChannelBroadcastReceiveOp::attachInterface(*ctx); SpatChannelBroadcastSendOp::attachInterface(*ctx); - SpatApplyFiltersOp::attachInterface(*ctx); }); } diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp index d11b4c0..0c2e207 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp @@ -247,11 +247,11 @@ struct FoldConstantTransposePattern final : OpRewritePattern(transposeOp.getOutRes().getType()); + auto resultType = dyn_cast(transposeOp.getOutput().getType()); if (!resultType || !resultType.hasStaticShape()) return failure(); - auto sourceGetGlobal = transposeOp.getData().getDefiningOp(); + auto sourceGetGlobal = transposeOp.getInput().getDefiningOp(); if (!sourceGetGlobal) return failure(); @@ -268,8 +268,8 @@ struct FoldConstantTransposePattern final : OpRewritePattern perms; - perms.reserve(transposeOp.getPerms().size()); - for (IntegerAttr attr : transposeOp.getPerms().getAsRange()) + perms.reserve(transposeOp.getPermutation().size()); + for (IntegerAttr attr : transposeOp.getPermutation().getAsRange()) perms.push_back(attr.getInt()); FailureOr transposedAttr = transposeDenseElements(denseAttr, perms); if (failed(transposedAttr)) @@ -389,18 +389,18 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { if (copyOp->getParentOfType()) return failure(); - auto allocOp = copyOp.getDst().getDefiningOp(); + auto allocOp = copyOp.getTarget().getDefiningOp(); if (!allocOp) return failure(); auto allocType = dyn_cast(allocOp.getType()); if (!allocType || !allocType.hasStaticShape()) return failure(); - if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0) + if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0) return failure(); - auto srcSubview = getStaticSubviewInfo(copyOp.getSrc()); - Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc()); + auto srcSubview = getStaticSubviewInfo(copyOp.getSource()); + Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource()); auto moduleOp = copyOp->getParentOfType(); if (!moduleOp) diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp index 772f57b..6ee63a4 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp @@ -89,10 +89,10 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern auto status = rewriteSubviewCopyLikeOp(copyOp, - copyOp.getDst(), - copyOp.getSrc(), - copyOp.getDstOffset(), - copyOp.getSrcOffset(), + copyOp.getTarget(), + copyOp.getSource(), + copyOp.getTargetOffset(), + copyOp.getSourceOffset(), copyOp.getSize(), rewriter, [&](MemRefType resultType, @@ -114,7 +114,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern if (failed(status)) return failure(); - rewriter.replaceOp(copyOp, copyOp.getDst()); + rewriter.replaceOp(copyOp, copyOp.getTarget()); return success(); } }; @@ -125,10 +125,10 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern