fix pool lowering
Validate Operations / validate-operations (push) Has been cancelled

better reports (dcp merge and memory)
This commit is contained in:
NiccoloN
2026-05-12 12:32:23 +02:00
parent 8ad504fcdf
commit 80a7298552
8 changed files with 393 additions and 203 deletions
+25 -16
View File
@@ -33,12 +33,14 @@ inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimite
return parser.parseOptionalRParen(); return parser.parseOptionalRParen();
} }
inline void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { template <typename StreamT>
printer << (delimiter == ListDelimiter::Square ? "[" : "("); inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
} }
inline void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { template <typename StreamT>
printer << (delimiter == ListDelimiter::Square ? "]" : ")"); inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
} }
template <typename EntryT, typename ParseEntryFn> template <typename EntryT, typename ParseEntryFn>
@@ -163,8 +165,8 @@ inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin
} }
} }
template <typename IntT> template <typename StreamT, typename IntT>
inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) { inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
struct FlatCompression { struct FlatCompression {
enum class Kind { enum class Kind {
Single, Single,
@@ -271,41 +273,48 @@ inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT>
return std::pair(bestLength, bestRepeatCount); return std::pair(bestLength, bestRepeatCount);
}; };
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) { for (size_t index = 0; index < values.size();) {
if (index != 0) if (index != 0)
printer << ", "; stream << ", ";
FlatCompression flat = computeFlatCompression(index); FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren); printOpenDelimiter(stream, ListDelimiter::Paren);
printer << " x" << sublistRepeatCount; printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
printCloseDelimiter(stream, ListDelimiter::Paren);
stream << " x" << sublistRepeatCount;
index += repeatedSublistCoverage; index += repeatedSublistCoverage;
continue; continue;
} }
switch (flat.kind) { switch (flat.kind) {
case FlatCompression::Kind::Progression: case FlatCompression::Kind::Progression:
printer << flat.firstValue << " to " << flat.lastValue; stream << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1) if (flat.step != 1)
printer << " by " << flat.step; stream << " by " << flat.step;
if (flat.repeatCount > 1) if (flat.repeatCount > 1)
printer << " x" << flat.repeatCount; stream << " x" << flat.repeatCount;
index += flat.covered; index += flat.covered;
break; break;
case FlatCompression::Kind::EqualRun: case FlatCompression::Kind::EqualRun:
printer << flat.firstValue << " x" << flat.repeatCount; stream << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered; index += flat.covered;
break; break;
case FlatCompression::Kind::Single: case FlatCompression::Kind::Single:
printer << flat.firstValue; stream << flat.firstValue;
index += flat.covered; index += flat.covered;
break; break;
} }
} }
printCloseDelimiter(printer, delimiter); }
template <typename StreamT, typename IntT>
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
printOpenDelimiter(stream, delimiter);
printCompressedIntegerEntries(stream, values);
printCloseDelimiter(stream, delimiter);
} }
template <typename IntT> template <typename IntT>
+72 -19
View File
@@ -25,6 +25,7 @@
#include <utility> #include <utility>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Common/IR/CompactAsmUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" #include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
@@ -36,6 +37,7 @@
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
static size_t getValueSizeInBytes(mlir::Value value) { static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
@@ -125,26 +127,29 @@ std::string formatMemory(uint64_t bytes) {
return rss.str(); return rss.str();
} }
void PimMemory::report(llvm::raw_ostream& file) { static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
uint64_t numAlloca = 0; os << "\tNumber of allocas: " << row.numAlloca << "\n";
uint64_t sizeAlloca = 0; os << "\tAllocated memory: " << formatMemory(row.sizeAlloca) << "\n";
uint64_t numGlobal = 0; os << "\tNumber of globals: " << row.numGlobal << "\n";
uint64_t sizeGlobal = 0; os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
}
MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row;
for (auto& [val, memEntry] : globalMemEntriesMap) { for (auto& [val, memEntry] : globalMemEntriesMap) {
if (auto op = val.getDefiningOp()) { if (auto op = val.getDefiningOp()) {
if (auto allocaOp = dyn_cast<memref::AllocOp>(op)) { if (isa<memref::AllocOp>(op)) {
numAlloca++; row.numAlloca++;
sizeAlloca += memEntry.size; row.sizeAlloca += memEntry.size;
} }
if (auto allocaOp = dyn_cast<memref::GetGlobalOp>(op)) { if (isa<memref::GetGlobalOp>(op)) {
numGlobal++; row.numGlobal++;
sizeGlobal += memEntry.size; row.sizeGlobal += memEntry.size;
} }
} }
} }
return row;
file << numAlloca << " " << formatMemory(sizeAlloca) << " " << numGlobal << " " << formatMemory(sizeGlobal) << "\n";
} }
void PimMemory::remove(mlir::Value val) { void PimMemory::remove(mlir::Value val) {
@@ -193,17 +198,64 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
} }
void PimAcceleratorMemory::reportHost() { void PimAcceleratorMemory::reportHost() {
llvm::raw_os_ostream os(fileReport); hostReportRow = hostMem.getReportRow();
os << "Host Memory\t";
hostMem.report(os);
os.flush();
} }
void PimAcceleratorMemory::reportCore(size_t coreId) { void PimAcceleratorMemory::reportCore(size_t coreId) {
coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()});
}
void PimAcceleratorMemory::flushReport() {
if (!fileReport.is_open())
return;
llvm::raw_os_ostream os(fileReport); llvm::raw_os_ostream os(fileReport);
os << "Core " << coreId << " Memory\t"; if (hostReportRow.has_value()) {
deviceMem.at(coreId).report(os); os << "Host:\n";
printMemoryReportRow(os, *hostReportRow);
}
if (!coreReportRows.empty()) {
if (hostReportRow.has_value())
os << "\n";
llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) {
const MemoryReportRow& lhsRow = lhs.second;
const MemoryReportRow& rhsRow = rhs.second;
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
if (lhsRow.numAlloca != rhsRow.numAlloca)
return lhsRow.numAlloca > rhsRow.numAlloca;
if (lhsRow.sizeGlobal != rhsRow.sizeGlobal)
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
if (lhsRow.numGlobal != rhsRow.numGlobal)
return lhsRow.numGlobal > rhsRow.numGlobal;
return lhs.first < rhs.first;
});
for (size_t index = 0; index < coreReportRows.size();) {
size_t runEnd = index + 1;
while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second)
++runEnd;
llvm::SmallVector<size_t, 8> coreIds;
coreIds.reserve(runEnd - index);
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(coreReportRows[coreIndex].first);
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<size_t>(coreIds));
os << ":\n";
printMemoryReportRow(os, coreReportRows[index].second);
if (runEnd < coreReportRows.size())
os << "\n";
index = runEnd;
}
}
os.flush(); os.flush();
fileReport.close();
} }
void PimAcceleratorMemory::clean(mlir::Operation* op) { void PimAcceleratorMemory::clean(mlir::Operation* op) {
@@ -867,5 +919,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
} }
} }
memory.flushReport();
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath); return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
} }
+17 -1
View File
@@ -8,6 +8,7 @@
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <fstream> #include <fstream>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -20,6 +21,18 @@ struct MemEntry {
size_t size; size_t size;
}; };
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
uint64_t numGlobal = 0;
uint64_t sizeGlobal = 0;
bool operator==(const MemoryReportRow& other) const {
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
&& sizeGlobal == other.sizeGlobal;
}
};
class PimMemory { class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries; llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
@@ -37,7 +50,7 @@ public:
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op); void allocateCore(mlir::Operation* op);
void report(llvm::raw_ostream& os); MemoryReportRow getReportRow() const;
void remove(mlir::Value val); void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; } size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
@@ -52,6 +65,8 @@ public:
private: private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem; llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport; std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<std::pair<size_t, MemoryReportRow>, 32> coreReportRows;
public: public:
PimAcceleratorMemory() PimAcceleratorMemory()
@@ -72,6 +87,7 @@ public:
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost(); void reportHost();
void reportCore(size_t coreId); void reportCore(size_t coreId);
void flushReport();
void clean(mlir::Operation* op); void clean(mlir::Operation* op);
}; };
-13
View File
@@ -1,16 +1,3 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===------------------------- PimCompilerOptions.cpp --------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Compiler Options for PIM
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions" #define DEBUG_TYPE "PimCompilerOptions"
@@ -1,9 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <algorithm> #include <algorithm>
#include <optional> #include <optional>
@@ -30,16 +33,6 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
} }
template <typename PoolOp>
static FailureOr<Value> concatAlongAxis(
ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
if (values.empty()) {
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
return failure();
}
return createSpatConcat(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType()); auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
@@ -54,34 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides); return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
} }
template <typename ReduceOp> static Value createPoolFillElement(
static FailureOr<Value> ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) { if (!useMinimumValue)
if (windowValues.empty()) { return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
op->emitOpError("pool window resolved to zero valid elements");
if (auto floatType = dyn_cast<FloatType>(elementType)) {
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
}
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
}
llvm_unreachable("unsupported pool element type");
}
static Value createPoolFillTensor(
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
}
template <typename PoolOp>
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value input,
RankedTensorType inputType,
int64_t padTop,
int64_t padLeft,
int64_t padBottom,
int64_t padRight) {
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
return input;
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
inputType.getDimSize(1),
inputType.getDimSize(2) + padTop + padBottom,
inputType.getDimSize(3) + padLeft + padRight},
inputType.getElementType(),
inputType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padTop),
rewriter.getIndexAttr(padLeft)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padBottom),
rewriter.getIndexAttr(padRight)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
auto* padBlock = new Block();
for (int index = 0; index < paddedType.getRank(); ++index)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
Value padValue = createPoolFillElement(
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
tensor::YieldOp::create(rewriter, loc, padValue);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
Location loc,
Operation* op,
RankedTensorType outType,
int64_t channels,
int64_t inputHeight,
int64_t inputWidth,
int64_t outputHeight,
int64_t outputWidth,
int64_t kernelHeight,
int64_t kernelWidth,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t padTop,
int64_t padLeft,
bool countIncludePad) {
auto elemType = dyn_cast<FloatType>(outType.getElementType());
if (!elemType) {
op->emitOpError("AveragePool lowering requires a floating-point element type");
return failure(); return failure();
} }
Value reduced = windowValues.front(); auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
for (Value value : windowValues.drop_front()) SmallVector<Attribute> scaleValues;
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value); scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
return reduced; for (int64_t channel = 0; channel < channels; ++channel) {
(void) channel;
for (int64_t outH = 0; outH < outputHeight; ++outH) {
for (int64_t outW = 0; outW < outputWidth; ++outW) {
int64_t validCount = 0;
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
++validCount;
}
} }
static FailureOr<Value> scaleAverageWindow( const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
if (divisor <= 0) { if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive"); op->emitOpError("AveragePool divisor must be positive");
return failure(); return failure();
} }
if (divisor == 1) scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
return reducedWindow; }
}
}
auto tileType = cast<RankedTensorType>(reducedWindow.getType()); auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
double scale = 1.0 / static_cast<double>(divisor); return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
} }
template <typename PoolOp> template <typename PoolOp>
@@ -159,49 +244,90 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
} }
} }
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue()); const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize; const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
const bool countIncludePad = [&]() {
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
return poolOp.getCountIncludePad() == 1;
return true;
}();
Value averageScaleTensor;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
loc,
poolOp,
outType,
channels,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kernelHeight,
kernelWidth,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
padTop,
padLeft,
countIncludePad);
if (failed(maybeAverageScaleTensor))
return failure();
averageScaleTensor = *maybeAverageScaleTensor;
}
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult { createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
SmallVector<Value> batchResults; Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
batchResults.reserve(batchSize); Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
for (int64_t batch = 0; batch < batchSize; ++batch) { Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value> rows; Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
rows.reserve(outputHeight); Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
for (int64_t outH = 0; outH < outputHeight; ++outH) { auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
SmallVector<Value> rowPixels; rewriter.setInsertionPointToStart(outputLoop.getBody());
rowPixels.reserve(outputWidth);
for (int64_t outW = 0; outW < outputWidth; ++outW) { Value outputPatchIndex = outputLoop.getInductionVar();
SmallVector<Value> outputChannelTiles; Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
outputChannelTiles.reserve(channelTileCount);
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
Value updatedOutput = pooledOutputAcc;
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize); const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value reducedWindow = createPoolFillTensor(
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop; Value paddedInH = windowBaseH;
if (inH < 0 || inH >= inputHeight) if (kernelH * dilationHeight != 0) {
continue; Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
}
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft; Value paddedInW = windowBaseW;
if (inW < 0 || inW >= inputWidth) if (kernelW * dilationWidth != 0) {
continue; Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
}
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch), SmallVector<OpFoldResult> offsets = {batchIndex,
rewriter.getIndexAttr(channelTile * xbarSize), rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH), paddedInH,
rewriter.getIndexAttr(inW)}; paddedInW};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels), rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
@@ -211,54 +337,51 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1)};
Value windowValue = Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides); tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue); windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue); reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
} }
} }
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
if (failed(reducedWindow))
return failure();
Value reducedWindowValue = *reducedWindow;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) { if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1; SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
const int64_t divisor = rewriter.getIndexAttr(channelTile * xbarSize),
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size()); outHeightIndex,
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor); outWidthIndex};
if (failed(scaledWindow)) SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
return failure(); rewriter.getIndexAttr(tileChannels),
reducedWindowValue = *scaledWindow; rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
} }
outputChannelTiles.push_back(reducedWindowValue); SmallVector<OpFoldResult> outputOffsets = {batchIndex,
rewriter.getIndexAttr(channelTile * xbarSize),
outHeightIndex,
outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
} }
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles); scf::YieldOp::create(rewriter, loc, updatedOutput);
if (failed(rowPixel))
return failure();
rowPixels.push_back(*rowPixel);
}
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels); rewriter.setInsertionPointAfter(outputLoop);
if (failed(row)) spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
return failure();
rows.push_back(*row);
}
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
if (failed(batchResult))
return failure();
batchResults.push_back(*batchResult);
}
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
if (failed(pooledOutput))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))
+2 -3
View File
@@ -3,7 +3,6 @@
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
@@ -460,11 +459,11 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("all outputs must have the same type"); return emitError("all outputs must have the same type");
} }
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) { if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr); auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
if (!coreIdsAttr) if (!coreIdsAttr)
return emitError("compute_batch coreIds attribute must be a dense i32 array"); return emitError("compute_batch coreIds attribute must be a dense i32 array");
if (coreIdsAttr.size() != laneCountSz) if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
return emitError("compute_batch coreIds array length must match laneCount"); return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
return emitError("compute_batch coreIds values must be positive"); return emitError("compute_batch coreIds values must be positive");
@@ -1,6 +1,5 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <vector> #include <vector>
#include "GraphSupport.hpp" #include "GraphSupport.hpp"
@@ -31,7 +30,7 @@ llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents
} }
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate) { GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate) {
return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false)}; return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false), {}};
} }
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task, LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
@@ -38,6 +38,7 @@
#include "DCPGraph/DCPAnalysis.hpp" #include "DCPGraph/DCPAnalysis.hpp"
#include "RegularOpCompaction.hpp" #include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -45,6 +46,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatCompute; using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch; using SpatComputeBatch = spatial::SpatComputeBatch;
@@ -766,10 +768,10 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
if (outputDir.empty()) if (outputDir.empty())
return; return;
std::string dialectsDir = outputDir + "/dcp_graph"; std::string reportsDir = outputDir + "/reports";
createDirectory(dialectsDir); createDirectory(reportsDir);
std::fstream file(dialectsDir + "/" + name + ".txt", std::ios::out); std::fstream file(reportsDir + "/" + name + ".txt", std::ios::out);
llvm::raw_os_ostream os(file); llvm::raw_os_ostream os(file);
struct ReportRow { struct ReportRow {
@@ -778,41 +780,42 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t weightCount = 0; uint64_t weightCount = 0;
uint64_t instructionCount = 0; uint64_t instructionCount = 0;
bool isRebatched = false; bool isRebatched = false;
SmallVector<int32_t> coreIds;
}; };
uint64_t totalComputeOps = 0; uint64_t totalComputeOps = 0;
uint64_t totalLogicalComputes = 0; uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0; uint64_t totalBatchComputeOps = 0;
uint64_t totalMultiLaneBatchComputeOps = 0;
std::vector<ReportRow> collectedData; std::vector<ReportRow> collectedData;
for (Operation& op : funcOp.getBody().front()) { for (Operation& op : funcOp.getBody().front()) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) { if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
uint64_t numInst = 0; uint64_t numInst = 0;
for (auto& _ : spatCompute.getRegion().front()) for (auto& _ : spatCompute.getRegion().front())
numInst++; ++numInst;
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false}); collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, {}});
totalLogicalComputes += 1; totalLogicalComputes += 1;
continue; continue;
} }
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) { if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
uint64_t numInst = 0; uint64_t numInst = 0;
for (auto& _ : batch.getRegion().front()) for (auto& _ : batch.getRegion().front())
numInst++; ++numInst;
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount()); uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
collectedData.push_back({totalComputeOps++, logicalCount, batch.getWeights().size(), numInst, true}); SmallVector<int32_t> coreIds;
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
collectedData.push_back({totalComputeOps++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
totalLogicalComputes += logicalCount; totalLogicalComputes += logicalCount;
totalBatchComputeOps += 1; totalBatchComputeOps += 1;
if (batch.getLaneCount() > 1)
totalMultiLaneBatchComputeOps += 1;
} }
} }
os << "Used CPUs: " << usedCpuCount << "\n"; os << "Used cores: " << usedCpuCount << "\n";
os << "Number of top-level compute ops: " << totalComputeOps << "\n"; os << "Number of top-level compute ops: " << totalComputeOps << "\n";
os << "Number of logical computes: " << totalLogicalComputes << "\n"; os << "Number of logical computes: " << totalLogicalComputes << "\n";
os << "Number of top-level batch compute ops: " << totalBatchComputeOps << "\n"; os << "Number of top-level batch compute ops: " << totalBatchComputeOps << "\n";
os << "Number of top-level multi-lane batch compute ops: " << totalMultiLaneBatchComputeOps << "\n\n"; os << "\n";
std::stable_sort(collectedData.begin(), collectedData.end(), [](const ReportRow& lft, const ReportRow& rgt) { std::stable_sort(collectedData.begin(), collectedData.end(), [](const ReportRow& lft, const ReportRow& rgt) {
if (lft.isRebatched != rgt.isRebatched) if (lft.isRebatched != rgt.isRebatched)
@@ -855,31 +858,32 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
break; break;
} }
os << (current.isRebatched ? "Batch " : "Compute ") << current.opId; if (current.isRebatched) {
auto expectedPrintedValue = current.opId + 1; os << "Batch ";
bool rangePrinted = false; for (uint64_t index = cI; index <= lastIndex; ++index) {
cI++; if (index != cI)
for (; cI <= lastIndex; ++cI) { os << ",\n ";
auto candidateToPrint = collectedData[cI].opId; os << collectedData[index].opId << " (cores ";
if (candidateToPrint == expectedPrintedValue) { if (collectedData[index].coreIds.empty())
expectedPrintedValue = candidateToPrint + 1; os << "unknown";
rangePrinted = true; else
printCompressedIntegerEntries(os, ArrayRef<int32_t>(collectedData[index].coreIds));
os << ")";
}
} }
else { else {
if (rangePrinted) os << "Compute ";
os << " - " << expectedPrintedValue - 1; SmallVector<uint64_t> opIds;
os << " , " << candidateToPrint; opIds.reserve(lastIndex - cI + 1);
rangePrinted = false; for (uint64_t index = cI; index <= lastIndex; ++index)
expectedPrintedValue = candidateToPrint + 1; opIds.push_back(collectedData[index].opId);
printCompressedIntegerEntries(os, ArrayRef<uint64_t>(opIds));
} }
}
if (rangePrinted && current.opId != expectedPrintedValue - 1)
os << " - " << expectedPrintedValue - 1;
os << ":\n"; os << ":\n";
os << "\tNumber of logical computes " << current.logicalComputeCount << "\n"; os << "\tNumber of logical computes: " << current.logicalComputeCount << "\n";
os << "\tNumber of instructions " << current.instructionCount << "\n"; os << "\tNumber of instructions: " << current.instructionCount << "\n";
os << "\tNumber of used crossbars " << current.weightCount << "\n"; os << "\tNumber of used crossbars: " << current.weightCount << "\n";
cI = lastIndex; cI = lastIndex;
} }
@@ -1438,7 +1442,7 @@ public:
return; return;
} }
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged"); dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size()); generateReport(func, "dcp_merge_report", analysisResult.cpuToLastComputeMap.size());
} }
private: private: