diff --git a/src/PIM/Common/IR/CompactAsmUtils.hpp b/src/PIM/Common/IR/CompactAsmUtils.hpp index 6bada90..814f9d4 100644 --- a/src/PIM/Common/IR/CompactAsmUtils.hpp +++ b/src/PIM/Common/IR/CompactAsmUtils.hpp @@ -33,12 +33,14 @@ inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimite return parser.parseOptionalRParen(); } -inline void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - printer << (delimiter == ListDelimiter::Square ? "[" : "("); +template +inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) { + stream << (delimiter == ListDelimiter::Square ? "[" : "("); } -inline void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - printer << (delimiter == ListDelimiter::Square ? "]" : ")"); +template +inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) { + stream << (delimiter == ListDelimiter::Square ? "]" : ")"); } template @@ -163,8 +165,8 @@ inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin } } -template -inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef values, ListDelimiter delimiter) { +template +inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef values) { struct FlatCompression { enum class Kind { Single, @@ -271,41 +273,48 @@ inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef return std::pair(bestLength, bestRepeatCount); }; - printOpenDelimiter(printer, delimiter); for (size_t index = 0; index < values.size();) { if (index != 0) - printer << ", "; + stream << ", "; FlatCompression flat = computeFlatCompression(index); auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { - printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren); - printer << " x" << sublistRepeatCount; + printOpenDelimiter(stream, ListDelimiter::Paren); + printCompressedIntegerEntries(stream, values.slice(index, sublistLength)); + printCloseDelimiter(stream, ListDelimiter::Paren); + stream << " x" << sublistRepeatCount; index += repeatedSublistCoverage; continue; } switch (flat.kind) { case FlatCompression::Kind::Progression: - printer << flat.firstValue << " to " << flat.lastValue; + stream << flat.firstValue << " to " << flat.lastValue; if (flat.step != 1) - printer << " by " << flat.step; + stream << " by " << flat.step; if (flat.repeatCount > 1) - printer << " x" << flat.repeatCount; + stream << " x" << flat.repeatCount; index += flat.covered; break; case FlatCompression::Kind::EqualRun: - printer << flat.firstValue << " x" << flat.repeatCount; + stream << flat.firstValue << " x" << flat.repeatCount; index += flat.covered; break; case FlatCompression::Kind::Single: - printer << flat.firstValue; + stream << flat.firstValue; index += flat.covered; break; } } - printCloseDelimiter(printer, delimiter); +} + +template +inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef values, ListDelimiter delimiter) { + printOpenDelimiter(stream, delimiter); + printCompressedIntegerEntries(stream, values); + printCloseDelimiter(stream, delimiter); } template diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 0bf1d39..6894558 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -25,6 +25,7 @@ #include #include "Common/PimCommon.hpp" +#include "Common/IR/CompactAsmUtils.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" #include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" @@ -36,6 +37,7 @@ using namespace llvm; using namespace mlir; using namespace onnx_mlir; +using namespace onnx_mlir::compact_asm; static size_t getValueSizeInBytes(mlir::Value value) { auto type = cast(value.getType()); @@ -125,26 +127,29 @@ std::string formatMemory(uint64_t bytes) { return rss.str(); } -void PimMemory::report(llvm::raw_ostream& file) { - uint64_t numAlloca = 0; - uint64_t sizeAlloca = 0; - uint64_t numGlobal = 0; - uint64_t sizeGlobal = 0; +static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) { + os << "\tNumber of allocas: " << row.numAlloca << "\n"; + os << "\tAllocated memory: " << formatMemory(row.sizeAlloca) << "\n"; + os << "\tNumber of globals: " << row.numGlobal << "\n"; + os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n"; +} + +MemoryReportRow PimMemory::getReportRow() const { + MemoryReportRow row; for (auto& [val, memEntry] : globalMemEntriesMap) { if (auto op = val.getDefiningOp()) { - if (auto allocaOp = dyn_cast(op)) { - numAlloca++; - sizeAlloca += memEntry.size; + if (isa(op)) { + row.numAlloca++; + row.sizeAlloca += memEntry.size; } - if (auto allocaOp = dyn_cast(op)) { - numGlobal++; - sizeGlobal += memEntry.size; + if (isa(op)) { + row.numGlobal++; + row.sizeGlobal += memEntry.size; } } } - - file << numAlloca << " " << formatMemory(sizeAlloca) << " " << numGlobal << " " << formatMemory(sizeGlobal) << "\n"; + return row; } void PimMemory::remove(mlir::Value val) { @@ -193,17 +198,64 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu } void PimAcceleratorMemory::reportHost() { - llvm::raw_os_ostream os(fileReport); - os << "Host Memory\t"; - hostMem.report(os); - os.flush(); + hostReportRow = hostMem.getReportRow(); } void PimAcceleratorMemory::reportCore(size_t coreId) { + coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()}); +} + +void PimAcceleratorMemory::flushReport() { + if (!fileReport.is_open()) + return; + llvm::raw_os_ostream os(fileReport); - os << "Core " << coreId << " Memory\t"; - deviceMem.at(coreId).report(os); + if (hostReportRow.has_value()) { + os << "Host:\n"; + printMemoryReportRow(os, *hostReportRow); + } + + if (!coreReportRows.empty()) { + if (hostReportRow.has_value()) + os << "\n"; + + llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) { + const MemoryReportRow& lhsRow = lhs.second; + const MemoryReportRow& rhsRow = rhs.second; + if (lhsRow.sizeAlloca != rhsRow.sizeAlloca) + return lhsRow.sizeAlloca > rhsRow.sizeAlloca; + if (lhsRow.numAlloca != rhsRow.numAlloca) + return lhsRow.numAlloca > rhsRow.numAlloca; + if (lhsRow.sizeGlobal != rhsRow.sizeGlobal) + return lhsRow.sizeGlobal > rhsRow.sizeGlobal; + if (lhsRow.numGlobal != rhsRow.numGlobal) + return lhsRow.numGlobal > rhsRow.numGlobal; + return lhs.first < rhs.first; + }); + + for (size_t index = 0; index < coreReportRows.size();) { + size_t runEnd = index + 1; + while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second) + ++runEnd; + + llvm::SmallVector coreIds; + coreIds.reserve(runEnd - index); + for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex) + coreIds.push_back(coreReportRows[coreIndex].first); + + os << "Core "; + printCompressedIntegerEntries(os, ArrayRef(coreIds)); + os << ":\n"; + printMemoryReportRow(os, coreReportRows[index].second); + if (runEnd < coreReportRows.size()) + os << "\n"; + + index = runEnd; + } + } + os.flush(); + fileReport.close(); } void PimAcceleratorMemory::clean(mlir::Operation* op) { @@ -867,5 +919,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: } } + memory.flushReport(); return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath); } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index d03704b..4a6e0f5 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -8,6 +8,7 @@ #include "llvm/Support/raw_os_ostream.h" #include +#include #include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" @@ -20,6 +21,18 @@ struct MemEntry { size_t size; }; +struct MemoryReportRow { + uint64_t numAlloca = 0; + uint64_t sizeAlloca = 0; + uint64_t numGlobal = 0; + uint64_t sizeGlobal = 0; + + bool operator==(const MemoryReportRow& other) const { + return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal + && sizeGlobal == other.sizeGlobal; + } +}; + class PimMemory { llvm::SmallVector, 32> memEntries; llvm::SmallDenseMap& globalMemEntriesMap; @@ -37,7 +50,7 @@ public: void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); void allocateCore(mlir::Operation* op); - void report(llvm::raw_ostream& os); + MemoryReportRow getReportRow() const; void remove(mlir::Value val); size_t getFirstAvailableAddress() const { return firstAvailableAddress; } @@ -52,6 +65,8 @@ public: private: llvm::SmallDenseMap deviceMem; std::fstream fileReport; + std::optional hostReportRow; + llvm::SmallVector, 32> coreReportRows; public: PimAcceleratorMemory() @@ -72,6 +87,7 @@ public: size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; void reportHost(); void reportCore(size_t coreId); + void flushReport(); void clean(mlir::Operation* op); }; diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 1b3288a..ad9bc05 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -1,16 +1,3 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===------------------------- PimCompilerOptions.cpp --------------------===// -// -// Copyright 2022 The IBM Research Authors. -// -// ============================================================================= -// -// Compiler Options for PIM -// -//===----------------------------------------------------------------------===// #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #define DEBUG_TYPE "PimCompilerOptions" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index dbee081..d05016e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -1,9 +1,12 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include #include @@ -30,16 +33,6 @@ static int64_t getOptionalI64(std::optional arrayAttr, size_t index, return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; } -template -static FailureOr concatAlongAxis( - ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef values) { - if (values.empty()) { - poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty"); - return failure(); - } - return createSpatConcat(rewriter, loc, axis, values); -} - static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { auto tileType = cast(tile.getType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); @@ -54,34 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides); } -template -static FailureOr -reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef windowValues) { - if (windowValues.empty()) { - op->emitOpError("pool window resolved to zero valid elements"); - return failure(); +static Value createPoolFillElement( + ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { + if (!useMinimumValue) + return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); + + if (auto floatType = dyn_cast(elementType)) { + auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true); + return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue)); } - Value reduced = windowValues.front(); - for (Value value : windowValues.drop_front()) - reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value); - return reduced; + if (auto integerType = dyn_cast(elementType)) { + auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth()); + return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue)); + } + + llvm_unreachable("unsupported pool element type"); } -static FailureOr scaleAverageWindow( - ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) { - if (divisor <= 0) { - op->emitOpError("AveragePool divisor must be positive"); +static Value createPoolFillTensor( + ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) { + auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue); + return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement); +} + +template +static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter, + Location loc, + PoolOp poolOp, + Value input, + RankedTensorType inputType, + int64_t padTop, + int64_t padLeft, + int64_t padBottom, + int64_t padRight) { + if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0) + return input; + + auto paddedType = RankedTensorType::get({inputType.getDimSize(0), + inputType.getDimSize(1), + inputType.getDimSize(2) + padTop + padBottom, + inputType.getDimSize(3) + padLeft + padRight}, + inputType.getElementType(), + inputType.getEncoding()); + SmallVector lowPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padTop), + rewriter.getIndexAttr(padLeft)}; + SmallVector highPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padBottom), + rewriter.getIndexAttr(padRight)}; + auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads); + auto* padBlock = new Block(); + for (int index = 0; index < paddedType.getRank(); ++index) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + Value padValue = createPoolFillElement( + rewriter, loc, inputType.getElementType(), std::is_same_v); + tensor::YieldOp::create(rewriter, loc, padValue); + rewriter.setInsertionPointAfter(padOp); + return padOp.getResult(); +} + +static FailureOr createAverageScaleTensor(ConversionPatternRewriter& rewriter, + Location loc, + Operation* op, + RankedTensorType outType, + int64_t channels, + int64_t inputHeight, + int64_t inputWidth, + int64_t outputHeight, + int64_t outputWidth, + int64_t kernelHeight, + int64_t kernelWidth, + int64_t strideHeight, + int64_t strideWidth, + int64_t dilationHeight, + int64_t dilationWidth, + int64_t padTop, + int64_t padLeft, + bool countIncludePad) { + auto elemType = dyn_cast(outType.getElementType()); + if (!elemType) { + op->emitOpError("AveragePool lowering requires a floating-point element type"); return failure(); } - if (divisor == 1) - return reducedWindow; - auto tileType = cast(reducedWindow.getType()); - double scale = 1.0 / static_cast(divisor); - auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale)); - Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr); - return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult(); + auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding()); + SmallVector scaleValues; + scaleValues.reserve(static_cast(channels * outputHeight * outputWidth)); + for (int64_t channel = 0; channel < channels; ++channel) { + (void) channel; + for (int64_t outH = 0; outH < outputHeight; ++outH) { + for (int64_t outW = 0; outW < outputWidth; ++outW) { + int64_t validCount = 0; + for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { + const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop; + if (inH < 0 || inH >= inputHeight) + continue; + for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { + const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft; + if (inW < 0 || inW >= inputWidth) + continue; + ++validCount; + } + } + + const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount; + if (divisor <= 0) { + op->emitOpError("AveragePool divisor must be positive"); + return failure(); + } + scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast(divisor))); + } + } + } + + auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues); + return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult(); } template @@ -159,106 +244,144 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { } } - (void) padBottom; - (void) padRight; - const int64_t xbarSize = static_cast(crossbarSize.getValue()); const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize; + const int64_t outputPatchCount = batchSize * outputHeight * outputWidth; + const bool countIncludePad = [&]() { + if constexpr (std::is_same_v) + return poolOp.getCountIncludePad() == 1; + return true; + }(); + Value averageScaleTensor; + if constexpr (std::is_same_v) { + auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter, + loc, + poolOp, + outType, + channels, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kernelHeight, + kernelWidth, + strideHeight, + strideWidth, + dilationHeight, + dilationWidth, + padTop, + padLeft, + countIncludePad); + if (failed(maybeAverageScaleTensor)) + return failure(); + averageScaleTensor = *maybeAverageScaleTensor; + } constexpr size_t numInputs = 1; auto computeOp = createSpatCompute(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult { - SmallVector batchResults; - batchResults.reserve(batchSize); + Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); + Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()); - for (int64_t batch = 0; batch < batchSize; ++batch) { - SmallVector rows; - rows.reserve(outputHeight); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount); + Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth); + Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth); + Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); + Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); - for (int64_t outH = 0; outH < outputHeight; ++outH) { - SmallVector rowPixels; - rowPixels.reserve(outputWidth); + auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); + rewriter.setInsertionPointToStart(outputLoop.getBody()); - for (int64_t outW = 0; outW < outputWidth; ++outW) { - SmallVector outputChannelTiles; - outputChannelTiles.reserve(channelTileCount); + Value outputPatchIndex = outputLoop.getInductionVar(); + Value pooledOutputAcc = outputLoop.getRegionIterArgs().front(); - for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { - const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); - auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); + Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); + Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); + Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); + Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); + Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); + Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); - SmallVector windowValues; - windowValues.reserve(kernelHeight * kernelWidth); - for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { - const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop; - if (inH < 0 || inH >= inputHeight) - continue; + Value updatedOutput = pooledOutputAcc; + for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { + const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); + auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); + Value reducedWindow = createPoolFillTensor( + rewriter, loc, tileType, std::is_same_v); - for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { - const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft; - if (inW < 0 || inW >= inputWidth) - continue; - - SmallVector offsets = {rewriter.getIndexAttr(batch), - rewriter.getIndexAttr(channelTile * xbarSize), - rewriter.getIndexAttr(inH), - rewriter.getIndexAttr(inW)}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(tileChannels), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - Value windowValue = - tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides); - windowValue = materializeContiguousTile(rewriter, loc, windowValue); - windowValues.push_back(windowValue); - } - } - - if (windowValues.empty()) - return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements."); - - auto reducedWindow = reduceWindowValues(rewriter, loc, poolOp, windowValues); - if (failed(reducedWindow)) - return failure(); - Value reducedWindowValue = *reducedWindow; - if constexpr (std::is_same_v) { - const bool countIncludePad = poolOp.getCountIncludePad() == 1; - const int64_t divisor = - countIncludePad ? kernelHeight * kernelWidth : static_cast(windowValues.size()); - auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor); - if (failed(scaledWindow)) - return failure(); - reducedWindowValue = *scaledWindow; - } - - outputChannelTiles.push_back(reducedWindowValue); - } - - auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles); - if (failed(rowPixel)) - return failure(); - rowPixels.push_back(*rowPixel); + for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { + Value paddedInH = windowBaseH; + if (kernelH * dilationHeight != 0) { + Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight); + paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset); } - auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels); - if (failed(row)) - return failure(); - rows.push_back(*row); + for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { + Value paddedInW = windowBaseW; + if (kernelW * dilationWidth != 0) { + Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth); + paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); + } + + SmallVector offsets = {batchIndex, + rewriter.getIndexAttr(channelTile * xbarSize), + paddedInH, + paddedInW}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value windowValue = + tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); + windowValue = materializeContiguousTile(rewriter, loc, windowValue); + reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue); + } } - auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows); - if (failed(batchResult)) - return failure(); - batchResults.push_back(*batchResult); + if constexpr (std::is_same_v) { + SmallVector scaleOffsets = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(channelTile * xbarSize), + outHeightIndex, + outWidthIndex}; + SmallVector scaleSizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector scaleStrides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value scaleSlice = tensor::ExtractSliceOp::create( + rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); + scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice); + reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice); + } + + SmallVector outputOffsets = {batchIndex, + rewriter.getIndexAttr(channelTile * xbarSize), + outHeightIndex, + outWidthIndex}; + SmallVector outputSizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector outputStrides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + updatedOutput = tensor::InsertSliceOp::create( + rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides); } - auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults); - if (failed(pooledOutput)) - return failure(); - spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput); + scf::YieldOp::create(rewriter, loc, updatedOutput); + + rewriter.setInsertionPointAfter(outputLoop); + spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0)); return success(); }); if (failed(computeOp)) diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index e240f17..7845827 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -3,7 +3,6 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/LogicalResult.h" @@ -460,11 +459,11 @@ LogicalResult SpatComputeBatch::verify() { return emitError("all outputs must have the same type"); } - if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) { + if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) { auto coreIdsAttr = dyn_cast(coreIdAttr); if (!coreIdsAttr) return emitError("compute_batch coreIds attribute must be a dense i32 array"); - if (coreIdsAttr.size() != laneCountSz) + if (coreIdsAttr.size() != static_cast(laneCountSz)) return emitError("compute_batch coreIds array length must match laneCount"); if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) return emitError("compute_batch coreIds values must be positive"); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp index eefcabe..8d4c1fb 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp @@ -1,6 +1,5 @@ #include "llvm/ADT/STLExtras.h" -#include #include #include "GraphSupport.hpp" @@ -31,7 +30,7 @@ llvm::DenseSet collectReachableTasks(TaskDCP* root, bool followParents } GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate) { - return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false)}; + return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false), {}}; } LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task, diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index d26799e..c9a4865 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -38,6 +38,7 @@ #include "DCPGraph/DCPAnalysis.hpp" #include "RegularOpCompaction.hpp" +#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -45,6 +46,7 @@ using namespace mlir; namespace onnx_mlir { namespace { +using namespace onnx_mlir::compact_asm; using SpatCompute = spatial::SpatCompute; using SpatComputeBatch = spatial::SpatComputeBatch; @@ -766,10 +768,10 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu if (outputDir.empty()) return; - std::string dialectsDir = outputDir + "/dcp_graph"; - createDirectory(dialectsDir); + std::string reportsDir = outputDir + "/reports"; + createDirectory(reportsDir); - std::fstream file(dialectsDir + "/" + name + ".txt", std::ios::out); + std::fstream file(reportsDir + "/" + name + ".txt", std::ios::out); llvm::raw_os_ostream os(file); struct ReportRow { @@ -778,41 +780,42 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu uint64_t weightCount = 0; uint64_t instructionCount = 0; bool isRebatched = false; + SmallVector coreIds; }; uint64_t totalComputeOps = 0; uint64_t totalLogicalComputes = 0; uint64_t totalBatchComputeOps = 0; - uint64_t totalMultiLaneBatchComputeOps = 0; std::vector collectedData; for (Operation& op : funcOp.getBody().front()) { if (auto spatCompute = dyn_cast(&op)) { uint64_t numInst = 0; for (auto& _ : spatCompute.getRegion().front()) - numInst++; - collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false}); + ++numInst; + collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, {}}); totalLogicalComputes += 1; continue; } if (auto batch = dyn_cast(&op)) { uint64_t numInst = 0; for (auto& _ : batch.getRegion().front()) - numInst++; + ++numInst; uint64_t logicalCount = static_cast(batch.getLaneCount()); - collectedData.push_back({totalComputeOps++, logicalCount, batch.getWeights().size(), numInst, true}); + SmallVector coreIds; + if (auto coreIdsAttr = batch->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) + llvm::append_range(coreIds, coreIdsAttr.asArrayRef()); + collectedData.push_back({totalComputeOps++, logicalCount, batch.getWeights().size(), numInst, true, coreIds}); totalLogicalComputes += logicalCount; totalBatchComputeOps += 1; - if (batch.getLaneCount() > 1) - totalMultiLaneBatchComputeOps += 1; } } - os << "Used CPUs: " << usedCpuCount << "\n"; + os << "Used cores: " << usedCpuCount << "\n"; os << "Number of top-level compute ops: " << totalComputeOps << "\n"; os << "Number of logical computes: " << totalLogicalComputes << "\n"; os << "Number of top-level batch compute ops: " << totalBatchComputeOps << "\n"; - os << "Number of top-level multi-lane batch compute ops: " << totalMultiLaneBatchComputeOps << "\n\n"; + os << "\n"; std::stable_sort(collectedData.begin(), collectedData.end(), [](const ReportRow& lft, const ReportRow& rgt) { if (lft.isRebatched != rgt.isRebatched) @@ -855,31 +858,32 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu break; } - os << (current.isRebatched ? "Batch " : "Compute ") << current.opId; - auto expectedPrintedValue = current.opId + 1; - bool rangePrinted = false; - cI++; - for (; cI <= lastIndex; ++cI) { - auto candidateToPrint = collectedData[cI].opId; - if (candidateToPrint == expectedPrintedValue) { - expectedPrintedValue = candidateToPrint + 1; - rangePrinted = true; - } - else { - if (rangePrinted) - os << " - " << expectedPrintedValue - 1; - os << " , " << candidateToPrint; - rangePrinted = false; - expectedPrintedValue = candidateToPrint + 1; + if (current.isRebatched) { + os << "Batch "; + for (uint64_t index = cI; index <= lastIndex; ++index) { + if (index != cI) + os << ",\n "; + os << collectedData[index].opId << " (cores "; + if (collectedData[index].coreIds.empty()) + os << "unknown"; + else + printCompressedIntegerEntries(os, ArrayRef(collectedData[index].coreIds)); + os << ")"; } } - if (rangePrinted && current.opId != expectedPrintedValue - 1) - os << " - " << expectedPrintedValue - 1; + else { + os << "Compute "; + SmallVector opIds; + opIds.reserve(lastIndex - cI + 1); + for (uint64_t index = cI; index <= lastIndex; ++index) + opIds.push_back(collectedData[index].opId); + printCompressedIntegerEntries(os, ArrayRef(opIds)); + } - os << " :\n"; - os << "\tNumber of logical computes " << current.logicalComputeCount << "\n"; - os << "\tNumber of instructions " << current.instructionCount << "\n"; - os << "\tNumber of used crossbars " << current.weightCount << "\n"; + os << ":\n"; + os << "\tNumber of logical computes: " << current.logicalComputeCount << "\n"; + os << "\tNumber of instructions: " << current.instructionCount << "\n"; + os << "\tNumber of used crossbars: " << current.weightCount << "\n"; cI = lastIndex; } @@ -1438,7 +1442,7 @@ public: return; } dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); - generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size()); + generateReport(func, "dcp_merge_report", analysisResult.cpuToLastComputeMap.size()); } private: