Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor

ilgeco
2026-05-12 10:43:01 +02:00
84 changed files with 4048 additions and 3310 deletions
+17 -36
@@ -72,9 +72,8 @@ inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
 }
 template <typename IntT>
-inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser,
-    ListDelimiter delimiter,
-    SmallVectorImpl<IntT>& values) {
+inline ParseResult
+parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
   if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
     return success();
@@ -113,8 +112,8 @@ inline ParseResult parseCompressedIntegerEntries(OpAsmParser&,
       return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
   }
   if ((last - first) % step != 0) {
-    return parser.emitError(
-        parser.getCurrentLocation(), "range end must be reachable from start using the given step");
+    return parser.emitError(parser.getCurrentLocation(),
+        "range end must be reachable from start using the given step");
   }
   for (int64_t value = first; value <= last; value += step)
@@ -140,9 +139,8 @@ inline ParseResult parseCompressedIntegerEntries(OpAsmParser&,
 }
 template <typename IntT>
-inline ParseResult parseCompressedIntegerSequence(OpAsmParser& parser,
-    ListDelimiter delimiter,
-    SmallVectorImpl<IntT>& values) {
+inline ParseResult
+parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
   if (parseOpenDelimiter(parser, delimiter))
     return failure();
   return parseCompressedIntegerEntries(parser, delimiter, values);
@@ -166,9 +164,7 @@ inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin
 }
 template <typename IntT>
-inline void printCompressedIntegerSequence(OpAsmPrinter& printer,
-    ArrayRef<IntT> values,
-    ListDelimiter delimiter) {
+inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
   struct FlatCompression {
     enum class Kind {
       Single,
@@ -388,9 +384,7 @@ inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List
   printCloseDelimiter(printer, delimiter);
 }
-inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser,
-    SmallVectorImpl<Type>& types,
-    bool allowEmpty) {
+inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
   Type firstType;
   OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
   if (!firstTypeResult.has_value()) {
@@ -422,8 +416,7 @@ inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser,
   return success();
 }
-inline ParseResult parseCompressedOperandEntryWithFirst(
-    OpAsmParser& parser,
+inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
     OpAsmParser::UnresolvedOperand firstOperand,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
   if (succeeded(parser.parseOptionalKeyword("to"))) {
@@ -447,8 +440,7 @@ inline ParseResult parseCompressedOperandEntryWithFirst(
   return success();
 }
-inline ParseResult parseOneCompressedOperandEntry(
-    OpAsmParser& parser,
+inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
   OpAsmParser::UnresolvedOperand firstOperand;
   if (parser.parseOperand(firstOperand))
@@ -474,8 +466,7 @@ inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
   }
 }
-inline ParseResult parseCompressedOperandSequence(
-    OpAsmParser& parser,
+inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
   if (parseOneCompressedOperandEntry(parser, operands))
     return failure();
@@ -485,9 +476,7 @@ inline ParseResult parseCompressedOperandSequence(
   return success();
 }
-inline ParseResult parseCompressedTypeList(OpAsmParser& parser,
-    ListDelimiter delimiter,
-    SmallVectorImpl<Type>& types) {
+inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
   if (parseOpenDelimiter(parser, delimiter))
     return failure();
   if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
@@ -522,10 +511,7 @@ inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
   return true;
 }
-inline void printValueTupleRun(OpAsmPrinter& printer,
-    ValueRange values,
-    size_t tupleSize,
-    ListDelimiter delimiter) {
+inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
   printOpenDelimiter(printer, delimiter);
   printOpenDelimiter(printer, ListDelimiter::Paren);
   for (size_t index = 0; index < tupleSize; ++index) {
@@ -538,10 +524,7 @@ inline void printValueTupleRun(OpAsmPrinter& printer,
   printCloseDelimiter(printer, delimiter);
 }
-inline void printTypeTupleRun(OpAsmPrinter& printer,
-    TypeRange types,
-    size_t tupleSize,
-    ListDelimiter delimiter) {
+inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
   printOpenDelimiter(printer, delimiter);
   printOpenDelimiter(printer, ListDelimiter::Paren);
   for (size_t index = 0; index < tupleSize; ++index) {
@@ -554,8 +537,7 @@ inline void printTypeTupleRun(OpAsmPrinter& printer,
   printCloseDelimiter(printer, delimiter);
 }
-inline ParseResult parseCompressedOrTupleOperandList(
-    OpAsmParser& parser,
+inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
     ListDelimiter delimiter,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
   if (parseOpenDelimiter(parser, delimiter))
@@ -604,9 +586,8 @@ inline ParseResult parseCompressedOrTupleOperandList(
   }
 }
-inline ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser,
-    ListDelimiter delimiter,
-    SmallVectorImpl<Type>& types) {
+inline ParseResult
+parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
   if (parseOpenDelimiter(parser, delimiter))
     return failure();
   if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
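
The error messages above pin down the compressed-entry grammar: an entry is a single integer, a run repeated with 'x', or an arithmetic range with a step, and the parser rejects non-positive repeat counts and ranges whose end is unreachable from the start. A minimal standalone sketch of those two expansion rules, with illustrative names (expandRepeat and expandRange are not from the repo):

#include <cstdint>
#include <optional>
#include <vector>

// "value x count": a run of identical entries; the count must be positive.
std::optional<std::vector<int64_t>> expandRepeat(int64_t value, int64_t count) {
  if (count <= 0)
    return std::nullopt; // "repeat count after 'x' must be positive"
  return std::vector<int64_t>(static_cast<size_t>(count), value);
}

// Range with a step: (last - first) must be divisible by step, mirroring the
// check and the expansion loop in parseCompressedIntegerEntries above.
std::optional<std::vector<int64_t>> expandRange(int64_t first, int64_t last, int64_t step) {
  if (step == 0 || (last - first) % step != 0)
    return std::nullopt; // "range end must be reachable from start using the given step"
  std::vector<int64_t> values;
  for (int64_t value = first; value <= last; value += step)
    values.push_back(value);
  return values;
}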
+3
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
 add_pim_library(OMPimCompilerUtils
   PimCompilerUtils.cpp
+  PimArtifactWriter.cpp
+  PimBatchEmission.cpp
   PimCodeGen.cpp
+  PimWeightEmitter.cpp
   EXCLUDE_FROM_OM_LIBS
+123
@@ -0,0 +1,123 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
std::error_code errorCode;
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
hostFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
if (denseAttr.isSplat()) {
size_t elementSize = rawData.size();
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
}
else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
std::memcpy(dst, rawData.data(), rawData.size());
}
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
configJson["core_cnt"] = maxCoreId + 1;
configJson["adc_count"] = 16;
configJson["cell_precision"] = 2;
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
} // namespace onnx_mlir
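
A hedged sketch of how these three writers are sequenced by a caller; the ordering mirrors compileToPimJson in the PimCodeGen.cpp diff further down, but emitAllArtifacts itself is illustrative and assumes memory, maxCoreId, and xbarsPerArrayGroup were already computed:

#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"

namespace onnx_mlir {
OnnxMlirCompilerErrorCodes emitAllArtifacts(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp,
    PimAcceleratorMemory& memory, size_t maxCoreId, llvm::json::Object xbarsPerArrayGroup,
    llvm::StringRef outputDirPath) {
  // memory.bin: constant initializers laid out at their allocated host addresses.
  if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
    return err;
  // core_0.json: the host-core stub that pimsim-nn expects to exist.
  if (auto err = writeHostCoreJson(outputDirPath))
    return err;
  // config.json: core count, crossbar geometry, and input/output addresses.
  return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
}
} // namespace onnx_mlir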
+26
@@ -0,0 +1,26 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
namespace onnx_mlir {
class PimAcceleratorMemory;
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
llvm::json::Object xbarsPerArrayGroup,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
+126
@@ -0,0 +1,126 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op)) {
pim::PimHaltOp::create(builder, op.getLoc());
continue;
}
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
builder,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
builder,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
hostSource = memcpBatchOp.getHostSource();
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
hostSource,
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
return callback(scalarCore);
}
} // namespace onnx_mlir
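
A hedged usage sketch of the helper above; forEachLaneCore is illustrative, not from the repo. Because the scalar core is created inside a scratch module owned by withScalarCoreFromBatchLane, the callback must not keep references to it past the call:

#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"

namespace onnx_mlir {
// Expand each lane of a pim.core_batch into a temporary scalar pim.core and process it.
mlir::LogicalResult forEachLaneCore(pim::PimCoreBatchOp coreBatchOp,
    llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> process) {
  for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
    if (mlir::failed(withScalarCoreFromBatchLane(coreBatchOp, lane, process)))
      return mlir::failure();
  return mlir::success();
}
} // namespace onnx_mlir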
+13
@@ -0,0 +1,13 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+23 -375
@@ -5,12 +5,10 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h" #include "mlir/IR/Verifier.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
@@ -21,7 +19,6 @@
 #include <absl/types/compare.h>
 #include <algorithm>
 #include <cassert>
-#include <cmath>
 #include <cstdint>
 #include <fstream>
 #include <string>
@@ -29,8 +26,11 @@
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
@@ -42,79 +42,6 @@ static size_t getValueSizeInBytes(mlir::Value value) {
   return type.getNumElements() * type.getElementTypeBitWidth() / 8;
 }
-struct DenseWeightView {
-  DenseElementsAttr denseAttr;
-  SmallVector<int64_t> shape;
-  SmallVector<int64_t> strides;
-  int64_t offset = 0;
-};
-static SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> strides(shape.size(), 1);
-  for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
-    strides[index] = strides[index + 1] * shape[index + 1];
-  return strides;
-}
-static bool allStaticSubviewParts(memref::SubViewOp subview) {
-  return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
-      && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
-      && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
-}
-static FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
-  SmallVector<memref::SubViewOp> subviews;
-  mlir::Value current = weight;
-  memref::GetGlobalOp getGlobalOp;
-  while (true) {
-    Operation* defOp = current.getDefiningOp();
-    if (!defOp)
-      return failure();
-    if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
-      break;
-    if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
-      if (!allStaticSubviewParts(subview))
-        return failure();
-      subviews.push_back(subview);
-      current = subview.getSource();
-      continue;
-    }
-    if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
-      current = cast.getSource();
-      continue;
-    }
-    return failure();
-  }
-  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
-  if (!globalOp || !globalOp.getInitialValue())
-    return failure();
-  auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
-  if (!denseAttr)
-    return failure();
-  DenseWeightView view;
-  view.denseAttr = denseAttr;
-  view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
-  view.strides = computeRowMajorStridesForShape(view.shape);
-  for (memref::SubViewOp subview : llvm::reverse(subviews)) {
-    SmallVector<int64_t> nextStrides;
-    nextStrides.reserve(subview.getStaticStrides().size());
-    for (auto [offset, stride, sourceStride] :
-        llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
-      view.offset += offset * sourceStride;
-      nextStrides.push_back(stride * sourceStride);
-    }
-    view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
-    view.strides = std::move(nextStrides);
-  }
-  return view;
-}
 MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
   auto type = cast<ShapedType>(value.getType());
   assert("Only static shape is supported" && type.hasStaticShape());
@@ -723,80 +650,6 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
   return coreLikeOps;
 }
-static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
-  OpBuilder builder(coreBatchOp);
-  builder.setInsertionPointAfter(coreBatchOp);
-  size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
-  size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
-  SmallVector<mlir::Value> laneWeights;
-  laneWeights.reserve(weightsPerLane);
-  for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
-    laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
-  auto coreIds = getBatchCoreIds(coreBatchOp);
-  auto scalarCore = pim::PimCoreOp::create(
-      builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
-  Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
-  IRMapping mapper;
-  if (coreBatchOp.getBody().front().getNumArguments() == 1)
-    mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
-  builder.setInsertionPointToEnd(block);
-  for (Operation& op : coreBatchOp.getBody().front()) {
-    if (isa<pim::PimHaltOp>(op)) {
-      pim::PimHaltOp::create(builder, op.getLoc());
-      continue;
-    }
-    if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
-      pim::PimSendOp::create(builder,
-          sendBatchOp.getLoc(),
-          mapper.lookup(sendBatchOp.getInput()),
-          sendBatchOp.getSizeAttr(),
-          builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
-      continue;
-    }
-    if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
-      auto scalarReceive =
-          pim::PimReceiveOp::create(builder,
-              receiveBatchOp.getLoc(),
-              receiveBatchOp.getOutput().getType(),
-              mapper.lookup(receiveBatchOp.getOutputBuffer()),
-              receiveBatchOp.getSizeAttr(),
-              builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
-      mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
-      continue;
-    }
-    if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
-      mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
-      if (!hostSource)
-        hostSource = memcpBatchOp.getHostSource();
-      auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
-          memcpBatchOp.getLoc(),
-          memcpBatchOp.getOutput().getType(),
-          mapper.lookup(memcpBatchOp.getDeviceTarget()),
-          hostSource,
-          memcpBatchOp.getDeviceTargetOffsetAttr(),
-          memcpBatchOp.getHostSourceOffsetAttr(),
-          memcpBatchOp.getSizeAttr());
-      mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
-      continue;
-    }
-    Operation* cloned = builder.clone(op, mapper);
-    for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
-      mapper.map(originalResult, clonedResult);
-  }
-  if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
-    pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
-  return scalarCore;
-}
 static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
     func::FuncOp funcOp,
     pim::PimCoreOp coreOp,
@@ -822,56 +675,6 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
   });
 }
-/// Write global constant data into a binary memory image at their allocated addresses.
-static OnnxMlirCompilerErrorCodes
-writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
-  auto memoryFilePath = (outputDirPath + "/memory.bin").str();
-  std::error_code errorCode;
-  raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
-  if (errorCode) {
-    errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
-    return InvalidOutputFileAccess;
-  }
-  std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
-  SmallPtrSet<Operation*, 16> writtenGlobals;
-  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
-    if (hasWeightAlways(getGlobalOp))
-      return;
-    auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
-    if (!globalOp)
-      return;
-    if (!writtenGlobals.insert(globalOp.getOperation()).second)
-      return;
-    auto initialValue = globalOp.getInitialValue();
-    if (!initialValue)
-      return;
-    auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
-    if (!denseAttr)
-      return;
-    MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
-    ArrayRef<char> rawData = denseAttr.getRawData();
-    char* dst = memoryBuffer.data() + memEntry.address;
-    if (denseAttr.isSplat()) {
-      size_t elementSize = rawData.size();
-      assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
-      for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
-        std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
-    }
-    else {
-      assert(rawData.size() == memEntry.size && "Data size mismatch");
-      std::memcpy(dst, rawData.data(), rawData.size());
-    }
-  });
-  memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
-  memoryFileStream.close();
-  return CompilerSuccess;
-}
 /// Dispatch all operations in a core region to the appropriate code generator.
 /// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
 /// fully resolved before the JSON instructions are emitted.
@@ -926,7 +729,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
       coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
     else {
       op.emitError("Unsupported codegen for this operation");
-      op.dump();
       return failure();
     }
     processedOperations++;
@@ -935,154 +737,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
   return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
 }
-llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
-createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
-  ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
-  auto coreWeightsDirPath = outputDirPath + "/weights";
-  auto error = sys::fs::create_directory(coreWeightsDirPath);
-  assert(!error && "Error creating weights directory");
-  size_t indexFileName = 0;
-  int64_t xbarSize = crossbarSize.getValue();
-  llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
-  llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
-  SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
-  for (Operation* op : coreLikeOps) {
-    SmallVector<pim::PimCoreOp> scalarCores;
-    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
-      scalarCores.push_back(coreOp);
-    }
-    else {
-      auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
-      for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
-        scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
-    }
-    for (pim::PimCoreOp coreOp : scalarCores) {
-      size_t coreId = static_cast<size_t>(coreOp.getCoreId());
-      for (unsigned index : getUsedWeightIndices(coreOp)) {
-        if (index >= coreOp.getWeights().size()) {
-          coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
-          assert(index < coreOp.getWeights().size() && "Weight index is out of range");
-        }
-        mlir::Value weight = coreOp.getWeights()[index];
-        auto weightView = resolveDenseWeightView(moduleOp, weight);
-        if (failed(weightView)) {
-          coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
-          assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
-        }
-        if (mapCoreWeightToFileName[coreId].contains(weight))
-          continue;
-        auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
-        auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
-        if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
-          auto& fileName = mapGlobalOpToFileName[globalOp];
-          mapCoreWeightToFileName[coreId].insert({weight, fileName});
-          continue;
-        }
-        DenseElementsAttr denseAttr = weightView->denseAttr;
-        ArrayRef<int64_t> shape = weightView->shape;
-        assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
-        int64_t numRows = shape[0];
-        int64_t numCols = shape[1];
-        assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
-        size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
-        std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
-        auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
-        std::error_code errorCode;
-        raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
-        if (errorCode) {
-          errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
-          assert(errorCode);
-        }
-        uint64_t zero = 0;
-        for (int64_t row = 0; row < xbarSize; row++) {
-          for (int64_t col = 0; col < xbarSize; col++) {
-            if (row < numRows && col < numCols) {
-              int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
-              APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
-              uint64_t word = bits.getZExtValue();
-              weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
-            }
-            else {
-              weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
-            }
-          }
-        }
-        weightFileStream.close();
-        if (globalOp)
-          mapGlobalOpToFileName.insert({globalOp, newFileName});
-        mapCoreWeightToFileName[coreId].insert({weight, newFileName});
-      }
-    }
-    for (pim::PimCoreOp coreOp : scalarCores)
-      if (coreOp.getOperation() != op)
-        coreOp.erase();
-  }
-  return mapCoreWeightToFileName;
-}
-/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
-static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
-    PimAcceleratorMemory& memory,
-    size_t maxCoreId,
-    json::Object xbarsPerArrayGroup,
-    StringRef outputDirPath) {
-  json::Object configJson;
-  // pimsim-nn indexes cores directly by their numeric core ID, with the host
-  // occupying core 0.
-  configJson["core_cnt"] = maxCoreId + 1;
-  // TODO: Should this be based on the floating point type used in the model?
-  // The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
-  // Number of ADC for MVM units
-  configJson["adc_count"] = 16;
-  // The bit precision of each ADC
-  configJson["cell_precision"] = 2;
-  // Crossbar configuration
-  configJson["xbar_array_count"] = crossbarCountInCore.getValue();
-  configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
-  configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
-  // Memory layout of inputs and outputs
-  json::Array inputsAddresses;
-  for (BlockArgument input : funcOp.getArguments())
-    inputsAddresses.push_back(memory.getValueAddress(input));
-  configJson["inputs_addresses"] = std::move(inputsAddresses);
-  json::Array outputsAddresses;
-  for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
-    for (mlir::Value output : returnOp.getOperands())
-      outputsAddresses.push_back(memory.getValueAddress(output));
-  configJson["outputs_addresses"] = std::move(outputsAddresses);
-  auto configPath = (outputDirPath + "/config.json").str();
-  std::error_code errorCode;
-  raw_fd_ostream jsonOS(configPath, errorCode);
-  if (errorCode) {
-    errs() << "Error while opening config file: " << errorCode.message() << '\n';
-    return InvalidOutputFileAccess;
-  }
-  jsonOS << json::Value(std::move(configJson)) << '\n';
-  jsonOS.close();
-  return CompilerSuccess;
-}
 OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
   if (!outputDirPath.empty()) {
     if (auto error = sys::fs::create_directory(outputDirPath)) {
@@ -1103,17 +757,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
   if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
     return err;
-  // Write empty host core file
-  std::error_code errorCode;
-  auto outputHostCorePath = outputDirPath + "/core_0.json";
-  raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
-  if (errorCode) {
-    errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
-    return InvalidOutputFileAccess;
-  }
-  // The host core json contains 2 random instructions, just to make pimsim-nn happy
-  hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
-  hostFileStream.close();
+  if (auto err = writeHostCoreJson(outputDirPath))
+    return err;
   // For each core, specify the number of crossbar per array group.
   // This implementation always assigns one crossbar per group.
@@ -1145,17 +790,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
   }
   for (Operation* op : coreLikeOps) {
-    SmallVector<pim::PimCoreOp> scalarCores;
-    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
-      scalarCores.push_back(coreOp);
-    }
-    else {
-      auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
-      for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
-        scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
-    }
-    for (pim::PimCoreOp coreOp : scalarCores) {
+    auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes {
       size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
       size_t coreId = emittedCoreIds.lookup(originalCoreId);
       maxCoreId = std::max(maxCoreId, coreId);
@@ -1210,12 +845,25 @@
       }
       xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
+      if (temporaryCore)
+        coreOp.walk([&memory](Operation* op) { memory.clean(op); });
+      return CompilerSuccess;
+    };
+    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
+      if (auto err = emitCore(coreOp, false))
+        return err;
+      continue;
     }
-    for (pim::PimCoreOp coreOp : scalarCores)
-      if (coreOp.getOperation() != op) {
-        coreOp.walk([&memory](Operation* op) { memory.clean(op); });
-        coreOp.erase();
+    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
+    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
+      OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
+      if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
+            laneResult = emitCore(coreOp, true);
+            return laneResult == CompilerSuccess ? success() : failure();
+          })))
+        return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
+    }
   }
 }
+221
@@ -0,0 +1,221 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
strides[index] = strides[index + 1] * shape[index + 1];
return strides;
}
bool allStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!allStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStridesForShape(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
}
return mapCoreWeightToFileName;
}
} // namespace onnx_mlir
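
A worked example of the offset/stride composition done by resolveDenseWeightView, under assumed shapes (a 64x128 row-major global seen through one static subview); the numbers are illustrative only:

#include <cassert>
#include <cstdint>

int main() {
  // Row-major 64x128 global: strides {128, 1} (computeRowMajorStridesForShape).
  const int64_t sourceStrides[2] = {128, 1};
  // Static subview with offsets {2, 4} and strides {1, 1}; sizes only set the view shape.
  const int64_t offsets[2] = {2, 4}, subviewStrides[2] = {1, 1};
  // The offset accumulates offset*sourceStride per dimension: 2*128 + 4*1 = 260.
  int64_t offset = offsets[0] * sourceStrides[0] + offsets[1] * sourceStrides[1];
  // Each view stride is subviewStride*sourceStride, so the view keeps {128, 1}.
  int64_t strides[2] = {subviewStrides[0] * sourceStrides[0], subviewStrides[1] * sourceStrides[1]};
  // Element (row=3, col=5) maps to flat index 260 + 3*128 + 5*1 = 649 in the dense
  // initializer, the same elementIndex formula used in the crossbar emission loop.
  assert(offset + 3 * strides[0] + 5 * strides[1] == 649);
  return 0;
}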
+16
@@ -0,0 +1,16 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include <string>
namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
 add_public_tablegen_target(ONNXToSpatialIncGen)
 add_pim_library(OMONNXToSpatial
+  ConversionPatterns.cpp
+  HostFoldability.cpp
+  HostLegality.cpp
+  PrePatterns.cpp
+  PostPatterns.cpp
   Patterns/Math/Conv.cpp
   Patterns/Math/Elementwise.cpp
   Patterns/Math/Gemm.cpp
@@ -1,8 +1,7 @@
 #pragma once
-#include "src/Accelerators/PIM/Common/PimCommon.hpp"
-#include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "ComputeRegionBuilder.hpp"
 #include "ShapeTilingUtils.hpp"
 #include "WeightMaterialization.hpp"
+#include "src/Accelerators/PIM/Common/PimCommon.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,6 +5,8 @@
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
using namespace mlir; using namespace mlir;
@@ -30,10 +32,29 @@ SmallVector<Value> sliceTensor(
   for (int64_t i = 0; i < numSlices; i++) {
     offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
-    if (i == numSlices - 1 && lastSliceSize != 0)
-      sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
-    Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
+    int64_t currentSliceSize = sliceSize;
+    if (i == numSlices - 1 && lastSliceSize != 0) {
+      currentSliceSize = lastSliceSize;
+      sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
+    }
+    SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
+    sliceShape[axis] = currentSliceSize;
+    auto sliceType =
+        RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
+    Value slice;
+    if (isHostFoldableValue(tensorToSlice)) {
+      slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
+    }
+    else {
+      auto sliceCompute =
+          createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
+            Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
+            spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
+          });
+      slice = sliceCompute.getResult(0);
+    }
     slices.push_back(slice);
   }
@@ -5,15 +5,15 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir { namespace onnx_mlir {
template <class ShapedType> template <class ShapedType>
@@ -105,7 +105,8 @@ inline auto getTensorShape(mlir::Value tensor) {
 inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
   auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
   auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
-  return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
+  return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
+      && lhsType.getShape() == rhsType.getShape();
 }
 /// Slices a statically shaped tensor along one axis into contiguous pieces of
@@ -5,12 +5,12 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "WeightMaterialization.hpp"
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -28,7 +28,7 @@ bool isWeightLikeComputeOperand(Value value) {
   while (auto* definingOp = value.getDefiningOp()) {
     if (!visited.insert(definingOp).second)
       return false;
-    if (hasWeightAlways(definingOp))
+    if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
       return true;
     if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
@@ -0,0 +1,32 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
}
} // namespace onnx_mlir
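
A hedged sketch of how populateConversionPatterns is typically consumed; this follows the standard MLIR greedy-driver idiom and is not necessarily the exact call site inside ONNXToSpatialPass::runOnOperation:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"

// Collect the commit's conversion patterns and rewrite the module to a fixed point.
static void runConversionPatterns(mlir::ModuleOp moduleOp) {
  mlir::RewritePatternSet patterns(moduleOp.getContext());
  onnx_mlir::populateConversionPatterns(patterns, moduleOp.getContext());
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(moduleOp, std::move(patterns))))
    moduleOp.emitError("ONNX-to-Spatial conversion patterns did not converge");
}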
@@ -5,6 +5,8 @@
 namespace onnx_mlir {
+void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
 void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -0,0 +1,75 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (!isStaticTensorResult(op))
return false;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
return false;
}
} // namespace
bool isHostFoldableValue(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
}
bool isHostFoldableOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
}
} // namespace onnx_mlir
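
A hedged sketch of how these predicates are meant to be used, mirroring the sliceTensor change earlier in this commit; sliceStaysOnHost is illustrative, not from the repo:

#include "mlir/IR/Value.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"

// True when the producer chain is a constant plus shape-only ops (transpose,
// collapse/expand, unit-stride slice, concat), so the slice can stay at top level;
// false means the caller should emit the slice inside a spat.compute region instead.
static bool sliceStaysOnHost(mlir::Value tensorToSlice) {
  return onnx_mlir::isHostFoldableValue(tensorToSlice);
}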
@@ -0,0 +1,12 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
} // namespace onnx_mlir
@@ -0,0 +1,29 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -8,21 +8,17 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <iterator>
#include <utility>
#include "Common/Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -33,8 +29,6 @@ namespace onnx_mlir {
namespace { namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> { struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; } StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
@@ -44,71 +38,64 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {} ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override; void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
void populateEmptyFunction(func::FuncOp funcOp);
}; };
} // namespace } // namespace
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) { static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
SmallVector<spatial::SpatComputeBatch> batchOps;
funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
for (auto batchOp : batchOps) {
if (batchOp.getLaneCount() != 1)
continue;
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper; IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
mapper.map(oldArg, newArg); if (!computes.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLocs;
sourceTypes.reserve(funcOp.getNumArguments());
sourceLocs.reserve(funcOp.getNumArguments());
for (Value source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, blockArg);
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
rewriter.setInsertionPointToEnd(newBlock); rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock) for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
batchOp.replaceAllUsesWith(computeOp.getResults()); auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
rewriter.eraseOp(batchOp); for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
} }
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
} }
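// Pass pipeline: greedy pre-patterns, partial ONNX-to-Spatial conversion,
// early post-patterns, core-count check, weight annotation, post-patterns,
// host-legality verification, and a fallback that wraps a compute-free
// function body in a single spat.compute.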
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx); RewritePatternSet prePatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstant>(ctx); populatePrePatterns(prePatterns, ctx);
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx); if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx); llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp);
auto entryFunc = getPimEntryFunc(moduleOp); auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) { if (failed(entryFunc)) {
signalPassFailure(); signalPassFailure();
@@ -140,34 +127,23 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXReduceMeanV13Op>(); target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>(); target.addIllegalOp<ONNXSplitOp>();
RewritePatternSet patterns(ctx); RewritePatternSet conversionPatterns(ctx);
patterns.add<removeLRN>(ctx); populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
foldSingleLaneComputeBatches(*entryFunc); RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
signalPassFailure();
return;
}
// Count the compute ops and check that they do not exceed the configured core count
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (auto& op : entryFunc->getFunctionBody().front().getOperations()) for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op)) if (isa<spatial::SpatCompute>(op))
computeOpsCount++; computeOpsCount++;
@@ -185,355 +161,23 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
signalPassFailure();
return;
}
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
signalPassFailure();
return;
}
populateEmptyFunction(*entryFunc); populateEmptyFunction(*entryFunc);
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
signalPassFailure();
return;
}
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
signalPassFailure();
return;
}
// Dump the module to a file for debugging
dumpModule(moduleOp, "spatial0"); dumpModule(moduleOp, "spatial0");
} }
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
auto source = toRemoveOp.getSource();
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(sources,
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = spatial::SpatConcatOp::create(rewriter,
loc,
toRemoveOp.getType(),
rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
ValueRange(BB->getArguments()));
spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
if (op == nullptr)
return false;
Operation* source = nullptr;
do {
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
return false;
}
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
auto tmpSource = extractSliceOp.getSource();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
auto tmpSource = extractRowsOp.getInput();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
auto tmpSource = expandShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
auto tmpSource = transposeOp.getData();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
auto tmpSource = collapseShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
source = constantOp;
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
bool res = false;
for (auto operand : concatOp.getOperands()) {
res |= hasWeightAlways(operand.getDefiningOp());
if (res)
return res;
}
return res;
}
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
bool res = false;
for (auto operand : concatOp.getOperands()) {
res |= hasWeightAlways(operand.getDefiningOp());
if (res)
return res;
}
return res;
}
else {
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
return failure();
}
}
while (source == nullptr);
return hasWeightAlways(source);
}
// TODO: what do we want to keep at the global level?
LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|| isa<func::ReturnOp>(instruction))
continue;
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
if (failed(weightBacked))
return failure();
if (*weightBacked)
continue;
keep |= encapsulateSlice(rewriter, loc, &instruction);
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
keep |= encapsulator<ONNXTransposeOp>(
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
keep |= encapsulator<tensor::CollapseShapeOp>(
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
return success();
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
continue;
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto& oldBlock = compute.getBody().front();
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
if (failed(clonedValue))
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
mapper.map(oldArg, *clonedValue);
}
for (auto& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
compute.replaceAllUsesWith(newCompute);
compute.erase();
}
return success();
}
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
return;
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, bbArg);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
rewriter.setInsertionPointToEnd(BB);
for (Operation& inst : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
rewriter.clone(inst, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)) {
inst.dropAllUses();
rewriter.eraseOp(&inst);
}
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); } std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -5,9 +5,9 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -11,6 +11,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
} }
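// The helpers below follow one convention: if the operand is host-foldable,
// emit the op directly on the host; otherwise wrap the same transformation in
// a single-operand spat.compute that yields the result.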
static Value transposeForSpatial(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isHostFoldableValue(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
if (isHostFoldableValue(value))
return tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
value,
SmallVector<ReassociationIndices> {
{0, 1}
});
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
input,
SmallVector<ReassociationIndices> {
{0, 1}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return computeOp.getResult(0);
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> { struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -81,6 +121,11 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType()); auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType); SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
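// Host-foldable matrices are row-split directly; runtime values go through
// the slice-compute chain below.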
if (isHostFoldableValue(matrix)) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
}
auto buildRowSlices = [&](Value matrixArg) { auto buildRowSlices = [&](Value matrixArg) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg); auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end()); return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
@@ -122,7 +167,8 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
rootValue = definingOp->getOperand(0); rootValue = definingOp->getOperand(0);
} }
return buildRowSlices(matrix); SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
} }
} // namespace } // namespace
@@ -175,13 +221,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) { if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, c = expandRankOneBias(c, expandedType, rewriter, loc);
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType; cType = expandedType;
} }
if (!cType.hasStaticShape()) { if (!cType.hasStaticShape()) {
@@ -196,25 +236,18 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
SmallVector<Value> cSlices;
if (hasC && cHasNumOutRows)
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
SmallVector<Value> gemvOps; SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows); gemvOps.reserve(static_cast<size_t>(numOutRows));
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) { for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c; Value cSlice = c;
if (hasC) { if (hasC) {
if (cHasNumOutRows) { if (cHasNumOutRows)
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; cSlice = cSlices[static_cast<size_t>(rowIdx)];
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
}
else if (!isVectorShape(getTensorShape(c))) { else if (!isVectorShape(getTensorShape(c))) {
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows"); gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
return failure(); return failure();
@@ -224,7 +257,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
auto gemvOp = ONNXGemmOp::create(rewriter, auto gemvOp = ONNXGemmOp::create(rewriter,
loc, loc,
outRowType, outRowType,
aSlice, aSlices[static_cast<size_t>(rowIdx)],
b, b,
cSlice, cSlice,
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
@@ -267,13 +300,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) { if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
gemmLoc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType; cType = expandedType;
} }
if (!cType.hasStaticShape()) { if (!cType.hasStaticShape()) {
@@ -305,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
if (transA) { if (transA) {
auto aShape = aType.getShape(); auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
aType = cast<RankedTensorType>(a.getType());
} }
if (transB) { if (transB) {
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
bType = cast<RankedTensorType>(b.getType()); bType = cast<RankedTensorType>(b.getType());
} }
@@ -335,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
auto bNumVSlices = aNumHSlices; auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices; auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize; auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices; auto outNumHSlices = cNumHSlices;
@@ -469,12 +496,15 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
if (gemmOpAdaptor.getTransB()) { if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType()); bType = cast<RankedTensorType>(b.getType());
} }
(void) bType; (void) bType;
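// The batched lowering only supports a host-foldable B matrix; bail out otherwise.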
if (!isHostFoldableValue(b))
return failure();
Value sharedBias; Value sharedBias;
if (hasC) { if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
@@ -484,13 +514,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
auto cType = cast<RankedTensorType>(c.getType()); auto cType = cast<RankedTensorType>(c.getType());
if (cType.getRank() == 1) { if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, c = expandRankOneBias(c, expandedType, rewriter, loc);
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = cast<RankedTensorType>(c.getType()); cType = cast<RankedTensorType>(c.getType());
} }
if (!cType.hasStaticShape()) { if (!cType.hasStaticShape()) {
@@ -2,11 +2,11 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -36,9 +36,9 @@ static Value extractBatchMatrix(Value value,
SmallVector<OpFoldResult> sizes = { SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType()); auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
return tensor::CollapseShapeOp::create(rewriter, return tensor::CollapseShapeOp::create(rewriter,
loc, loc,
matrixType, matrixType,
@@ -47,38 +47,16 @@ static Value extractBatchMatrix(Value value,
{0, 1}, {0, 1},
{2} {2}
}); });
} };
static bool isConstantLikeOperand(Value value) { if (isHostFoldableValue(value))
llvm::SmallPtrSet<Operation*, 8> visited; return buildMatrix(value);
while (auto* definingOp = value.getDefiningOp()) { auto batchMatrixCompute =
if (!visited.insert(definingOp).second) createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
return false; spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
if (definingOp->hasTrait<OpTrait::ConstantLike>()) });
return true; return batchMatrixCompute.getResult(0);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
} }
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
@@ -107,15 +85,31 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
perm = {0, 2, 1}; perm = {0, 2, 1};
} }
auto transposeCompute = auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
Value transposed =
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed); spatial::SpatYieldOp::create(rewriter, loc, transposed);
}); });
return transposeCompute.getResult(0); return transposeCompute.getResult(0);
} }
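// Concatenates along `axis` on the host when every input is host-foldable,
// otherwise inside a spat.compute; assumes static sizes along the concat axis.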
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> { struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@@ -157,7 +151,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
} }
Location loc = matmulOp.getLoc(); Location loc = matmulOp.getLoc();
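// When only A is host-foldable, lower the transposed product and transpose
// the result back afterwards (see below).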
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB()); bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
Value lhs = matmulOp.getA(); Value lhs = matmulOp.getA();
Value rhs = matmulOp.getB(); Value rhs = matmulOp.getB();
@@ -193,8 +187,14 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false)) rewriter.getBoolAttr(false))
.getY(); .getY();
if (useTransposedForm) if (useTransposedForm) {
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0})); auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
}
rewriter.replaceOp(matmulOp, gemmResult); rewriter.replaceOp(matmulOp, gemmResult);
return success(); return success();
} }
@@ -215,24 +215,30 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false)) rewriter.getBoolAttr(false))
.getY(); .getY();
if (useTransposedForm) auto batchResultCompute =
gemmResult = ONNXTransposeOp::create( createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
rewriter, Value resultMatrix = input;
if (useTransposedForm) {
resultMatrix = ONNXTransposeOp::create(rewriter,
loc, loc,
RankedTensorType::get({m, n}, outType.getElementType()), RankedTensorType::get({m, n}, outType.getElementType()),
gemmResult, input,
rewriter.getI64ArrayAttr({1, 0})); rewriter.getI64ArrayAttr({1, 0}));
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter, }
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
batchedOutType, batchedOutType,
gemmResult, resultMatrix,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0, 1}, {0, 1},
{2} {2}
})); });
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
batchResults.push_back(batchResultCompute.getResult(0));
} }
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults); Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, result);
return success(); return success();
} }
@@ -6,7 +6,8 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
return computeOp.getResult(0); return computeOp.getResult(0);
} }
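// Same host-vs-compute concat helper as in the MatMul patterns.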
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input, static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes, ArrayRef<bool> reducedAxes,
int64_t axis, int64_t axis,
@@ -100,7 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
for (Value slice : slices) for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc)); reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return createSpatConcat(rewriter, loc, axis, reducedSlices); return concatValues(reducedSlices, axis, rewriter, loc);
} }
static Value squeezeReducedAxes(Value keepdimsValue, static Value squeezeReducedAxes(Value keepdimsValue,
@@ -115,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element}); return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
} }
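// Rank-reducing collapse of the kept dims: on the host when the value is
// foldable, otherwise inside a spat.compute.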
return tensor::CollapseShapeOp::create( auto reassociation = buildCollapseReassociation(reducedAxes);
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes)) if (isHostFoldableValue(keepdimsValue))
.getResult(); return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
});
return squeezeCompute.getResult(0);
} }
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> { struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
@@ -31,8 +31,8 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
} }
template <typename PoolOp> template <typename PoolOp>
static FailureOr<Value> static FailureOr<Value> concatAlongAxis(
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) { ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
if (values.empty()) { if (values.empty()) {
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty"); poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
return failure(); return failure();
@@ -68,8 +68,8 @@ reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation*
return reduced; return reduced;
} }
static FailureOr<Value> static FailureOr<Value> scaleAverageWindow(
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) { ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
if (divisor <= 0) { if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive"); op->emitOpError("AveragePool divisor must be positive");
return failure(); return failure();
@@ -2,7 +2,8 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
return computeOp.getResult(0); return computeOp.getResult(0);
} }
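// Host-vs-compute concat helper for the softmax decomposition.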
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
@@ -47,7 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
for (Value slice : slices) for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return createSpatConcat(rewriter, loc, axis, rebuiltSlices); return concatValues(rebuiltSlices, axis, rewriter, loc);
} }
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> { struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
@@ -92,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value transposedInput = preTransposeCompute.getResult(0); Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax( Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
result = ONNXTransposeOp::create( auto postTransposeCompute =
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation)); createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
} }
rewriter.replaceOp(softmaxOp, result); rewriter.replaceOp(softmaxOp, result);
@@ -2,6 +2,8 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -18,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs(); auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis(); int64_t axis = adaptor.getAxis();
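// Fold entirely on the host when every operand is foldable; otherwise build
// the concat inside a spat.compute.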
if (llvm::all_of(inputs, isHostFoldableValue)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs)); rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
auto computeOp = createSpatCompute(
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
});
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
return success(); return success();
} }
@@ -6,7 +6,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -3,7 +3,10 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -95,18 +98,33 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return success(); return success();
} }
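// replaceWithReshape: emit the collapse/expand on the host when the source is
// foldable, otherwise inside a single-input spat.compute.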
SmallVector<ReassociationIndices> reassociation; auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (sourceType.getRank() > resultType.getRank() if (isHostFoldableValue(adaptor.getData())) {
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) { rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success(); return success();
} }
if (sourceType.getRank() < resultType.getRank() auto computeOp = createSpatCompute<1>(
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) { rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation); Value reshaped = buildReshape(data);
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
});
rewriter.replaceOp(reshapeOp, computeOp.getResults());
return success(); return success();
} };
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
return replaceWithReshape([&](Value data) {
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
return replaceWithReshape([&](Value data) {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
return failure(); return failure();
} }
@@ -6,7 +6,7 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -2,7 +2,9 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -47,18 +49,42 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
     outputs.reserve(splitOp.getNumResults());
     int64_t offset = 0;
+    SmallVector<RankedTensorType> resultTypes;
+    resultTypes.reserve(splitOp.getNumResults());
+    SmallVector<int64_t> sliceSizes;
+    sliceSizes.reserve(splitOp.getNumResults());
     for (Value result : splitOp.getResults()) {
       auto resultType = dyn_cast<RankedTensorType>(result.getType());
       if (!resultType || !resultType.hasStaticShape())
         return failure();
-      int64_t sliceSize = resultType.getShape()[axis];
+      resultTypes.push_back(resultType);
+      sliceSizes.push_back(resultType.getShape()[axis]);
+    }
+    if (isHostFoldableValue(adaptor.getInput())) {
+      for (int64_t sliceSize : sliceSizes) {
         outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
         offset += sliceSize;
       }
       rewriter.replaceOp(splitOp, outputs);
       return success();
     }
+    auto computeOp = createSpatCompute<1>(
+        rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) {
+          SmallVector<Value> runtimeOutputs;
+          runtimeOutputs.reserve(resultTypes.size());
+          int64_t runtimeOffset = 0;
+          for (int64_t sliceSize : sliceSizes) {
+            runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
+            runtimeOffset += sliceSize;
+          }
+          spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
+        });
+    rewriter.replaceOp(splitOp, computeOp.getResults());
+    return success();
+  }
 };
 } // namespace
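The Split lowering now gathers sliceSizes in a first pass so the same offset bookkeeping can be replayed either on the host or inside the spat.compute body. A stand-alone illustration of that bookkeeping, with assumed widths:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Splitting a 10-wide axis into results of widths 3, 3 and 4 extracts
  // slices starting at offsets 0, 3 and 6, mirroring the loops above.
  std::vector<int64_t> sliceSizes = {3, 3, 4};
  int64_t offset = 0;
  for (int64_t size : sliceSizes) {
    std::cout << "extract [" << offset << ", " << offset + size << ")\n";
    offset += size;
  }
}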
@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
} // namespace onnx_mlir
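A hedged sketch of how a pass body might drive the two phases declared in this file; the actual pass wiring is not part of this hunk, and applyPatternsAndFoldGreedily is the stock MLIR greedy driver:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"

// Hypothetical driver showing the intended phase order: early patterns first,
// then the weight-promotion patterns, then the constant annotation walk.
static mlir::LogicalResult runPostPhases(mlir::func::FuncOp funcOp) {
  mlir::MLIRContext* ctx = funcOp.getContext();
  mlir::RewritePatternSet early(ctx);
  onnx_mlir::populateEarlyPostPatterns(early, ctx);
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(funcOp, std::move(early))))
    return mlir::failure();
  mlir::RewritePatternSet late(ctx);
  onnx_mlir::populatePostPatterns(late, ctx);
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(funcOp, std::move(late))))
    return mlir::failure();
  onnx_mlir::annotateWeightsConstants(funcOp);
  return mlir::success();
}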
@@ -0,0 +1,14 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
namespace onnx_mlir {
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -0,0 +1,25 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -0,0 +1,224 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
Value input = mapper.lookup(sendTensorBatchOp.getInput());
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput =
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
input = packedInput;
pim::PimSendTensorBatchOp::create(
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty())
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
loc,
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
ValueRange(batchWeights),
ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
Block* newBlock =
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
newArg,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
}
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
};
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
continue;
}
}
for (Value operand : op.getOperands()) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
materializeCapturedTensor(operand);
}
Operation* cloned = rewriter.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
rewriter.setInsertionPointToEnd(newBlock);
PimHaltOp::create(rewriter, loc);
return success();
}
} // namespace onnx_mlir
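lowerComputeBatchOp is written to be driven from outside. One plausible call site, sketched under the assumption that the pass walks the module and shares a single CoreLoweringState across ops (the real driver lives elsewhere in this change and is not shown here):

#include "mlir/IR/BuiltinOps.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"

using namespace mlir;
using namespace onnx_mlir;

// Hypothetical driver: lower every compute_batch in the module, sharing one
// CoreLoweringState so fallback core ids stay globally unique.
static LogicalResult lowerAllComputeBatches(ModuleOp moduleOp, CoreLoweringState& state, IRRewriter& rewriter) {
  auto walkResult = moduleOp.walk([&](spatial::SpatComputeBatch op) {
    return failed(lowerComputeBatchOp(op, state, rewriter)) ? WalkResult::interrupt() : WalkResult::advance();
  });
  return failure(walkResult.wasInterrupted());
}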
@@ -0,0 +1,10 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -4,8 +4,16 @@ add_public_tablegen_target(SpatialToPimIncGen)
 add_pim_library(OMSpatialToPim
   SpatialToPimPass.cpp
+  BatchCoreLoweringPatterns.cpp
+  ChannelLoweringPatterns.cpp
+  Cleanup.cpp
   Common.cpp
-  Patterns.cpp
+  ComputeLikeRegionUtils.cpp
+  CoreLoweringPatterns.cpp
+  GlobalTensorMaterialization.cpp
+  PhaseVerification.cpp
+  ReturnPathNormalization.cpp
+  TensorPackingPatterns.cpp

   EXCLUDE_FROM_OM_LIBS
@@ -0,0 +1,136 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
if (op->use_empty()) {
rewriter.eraseOp(op);
return success();
}
auto outputType = cast<ShapedType>(op.getResult().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received = pim::PimReceiveOp::create(rewriter,
op.getLoc(),
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
auto inputType = cast<RankedTensorType>(op.getInput().getType());
SmallVector<Value> replacements;
replacements.reserve(op.getNumResults());
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
rewriter.getIndexAttr(inputType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacements.push_back(
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
.getResult());
}
rewriter.replaceOp(op, replacements);
return success();
}
};
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value concatenated =
pim::PimConcatOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
.getOutput();
rewriter.replaceOp(op, concatenated);
return success();
}
};
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
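Note that ExtractRowsLowering computes each offset as rowIndex * outputType.getDimSize(0), i.e. it assumes every output shares the same row count; the offsets are only correct under that assumption. A tiny numeric check:

#include <cstdint>
#include <iostream>

int main() {
  // Three 4-row outputs of a 12xN input: slices start at rows 0, 4 and 8.
  int64_t rowsPerOutput = 4;
  for (int64_t rowIndex = 0; rowIndex < 3; ++rowIndex)
    std::cout << "rows [" << rowIndex * rowsPerOutput << ", "
              << (rowIndex + 1) * rowsPerOutput << ")\n";
}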
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -0,0 +1,42 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
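erasePendingOps runs to a fixed point: each sweep erases whatever became dead during the previous one, and a sweep that makes no progress means a genuine leftover use, which is diagnosed instead of looping forever. A stand-alone analogue of the same loop, with plain integers standing in for ops:

#include <iostream>
#include <vector>

int main() {
  // useCount[i] models op i's remaining uses; op 0 is used only by op 1,
  // so op 1 must be erased first, then op 0 becomes erasable.
  std::vector<int> pending = {0, 1};
  std::vector<int> useCount = {1, 0};
  while (!pending.empty()) {
    bool erasedAny = false;
    for (auto it = pending.begin(); it != pending.end();) {
      if (useCount[*it] != 0) { ++it; continue; }
      if (*it == 1) --useCount[0]; // erasing op 1 releases its use of op 0
      it = pending.erase(it);
      erasedAny = true;
    }
    if (!erasedAny) { std::cout << "stuck: remaining users\n"; return 1; }
  }
  std::cout << "all erased\n";
}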
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -0,0 +1,44 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional<unsigned> {
if (inputCount == 0)
return std::nullopt;
unsigned inputBegin = op->getNumOperands() - inputCount;
if (operandNumber < inputBegin)
return std::nullopt;
return operandNumber - inputBegin;
};
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt;
}
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Operation* owner,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
rewriter.finalizeOpModification(owner);
}
} // namespace onnx_mlir
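The operand layout assumed here is [weights..., inputs...], so an operand number maps to an input index by subtracting the position of the first input. The same arithmetic on plain numbers, as a quick sanity check:

#include <iostream>
#include <optional>

// Mirrors getDirectComputeLikeInputIndex without any MLIR types.
std::optional<unsigned> inputIndexOf(unsigned numOperands, unsigned inputCount, unsigned operandNumber) {
  if (inputCount == 0)
    return std::nullopt;
  unsigned inputBegin = numOperands - inputCount;
  if (operandNumber < inputBegin)
    return std::nullopt;
  return operandNumber - inputBegin;
}

int main() {
  // 2 weights + 3 inputs: operand #3 is input #1; operand #1 is a weight.
  std::cout << *inputIndexOf(5, 3, 3) << '\n';            // prints 1
  std::cout << inputIndexOf(5, 3, 1).has_value() << '\n'; // prints 0
}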
@@ -0,0 +1,17 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber);
void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter,
mlir::Operation* owner,
unsigned inputIndex,
mlir::Value replacement);
} // namespace onnx_mlir
@@ -0,0 +1,213 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
return static_cast<int32_t>(fallbackCoreId++);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (requireReturnUse
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
}))
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return false;
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
computeOp.getResult(0).replaceAllUsesWith(replacement);
return true;
}
} // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
return success();
SmallVector<Operation*> helperChain;
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
return success();
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
continue;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
if (result.use_empty())
continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled)
continue;
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<spatial::SpatChannelSendOp>(resultUser))
continue;
}
return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
}
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
SmallVector<Value> computeWeights;
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock);
PimHaltOp::create(rewriter, computeOp.getLoc());
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -6,16 +6,17 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
-#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/LogicalResult.h"
 #include "Common/PimCommon.hpp"
+#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
+#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

 using namespace mlir;
@@ -23,33 +24,33 @@ using namespace mlir;
 namespace onnx_mlir {
 namespace {
-static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
-  if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
-    unsigned inputCount = compute.getInputs().size();
-    if (inputCount == 0)
-      return std::nullopt;
-    unsigned inputBegin = compute->getNumOperands() - inputCount;
-    if (operandNumber < inputBegin)
-      return std::nullopt;
-    return operandNumber - inputBegin;
-  }
-  if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
-    unsigned inputCount = computeBatch.getInputs().size();
-    if (inputCount == 0)
-      return std::nullopt;
-    unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
-    if (operandNumber < inputBegin)
-      return std::nullopt;
-    return operandNumber - inputBegin;
-  }
-  return std::nullopt;
-}
+static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
+  std::string name = baseName.str();
+  unsigned suffix = 0;
+  while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
+    name = (baseName + "_" + Twine(suffix++)).str();
+  return name;
+}
+
+static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter,
+    Location loc,
+    ModuleOp moduleOp,
+    StringRef baseName,
+    MemRefType type,
+    Attribute initialValue = {},
+    UnitAttr constant = {}) {
+  std::string symbolName = makeUniqueSymbolName(moduleOp, baseName);
+  return memref::GlobalOp::create(rewriter,
+      loc,
+      rewriter.getStringAttr(symbolName),
+      rewriter.getStringAttr("private"),
+      TypeAttr::get(type),
+      initialValue,
+      constant,
+      IntegerAttr {});
+}
+
+// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work.
 struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -59,7 +60,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
     for (auto& uses : extractSliceOp->getUses()) {
       if (isa<spatial::SpatCompute>(uses.getOwner())) {
-        if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
+        if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
           return failure();
       }
       else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
@@ -72,7 +73,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
     for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
       if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
-        auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
+        auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
         if (!inputIndex)
           return failure();
         auto BBArgIndex = *inputIndex;
@@ -87,14 +88,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
           mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
         }
-        rewriter.startOpModification(spatCompute.getOperation());
-        BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
-        spatCompute.getInputsMutable().erase(BBArgIndex);
-        spatCompute.getBody().front().eraseArgument(BBArgIndex);
-        rewriter.finalizeOpModification(spatCompute.getOperation());
+        replaceAndEraseDirectComputeLikeInput(
+            rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
       }
       else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
-        auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
+        auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
         if (!inputIndex)
           return failure();
         auto BBArgIndex = *inputIndex;
@@ -109,11 +107,8 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
           mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
         }
-        rewriter.startOpModification(spatComputeBatch.getOperation());
-        BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
-        spatComputeBatch.getInputsMutable().erase(BBArgIndex);
-        spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
-        rewriter.finalizeOpModification(spatComputeBatch.getOperation());
+        replaceAndEraseDirectComputeLikeInput(
+            rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
       }
       else {
         {
@@ -148,11 +143,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
   }
 };

+// Turns runtime constants consumed by compute regions into private globals and local loads.
 struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
-    static int i = 0;
     Location loc = constantOp.getLoc();

     if (hasWeightAlways(constantOp))
@@ -177,15 +172,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
     if (constRankedTensorType) {
       mlir::MemRefType memRefType =
           mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
-      std::string argName = "const_" + std::to_string(i++);
-      memref::GlobalOp::create(rewriter,
-          loc,
-          rewriter.getStringAttr(argName),
-          rewriter.getStringAttr("private"),
-          TypeAttr::get(memRefType),
-          constantOp.getValueAttr(),
-          rewriter.getUnitAttr(),
-          {});
+      auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
+          loc,
+          constantOp->getParentOfType<ModuleOp>(),
+          "const",
+          memRefType,
+          constantOp.getValueAttr(),
+          rewriter.getUnitAttr());
+      std::string argName = globalOp.getSymName().str();

       llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
@@ -193,11 +187,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
         auto constUsers = constUses.getOwner();
         if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
-          auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
+          auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
           if (!inputIndex)
             return failure();
           auto BBArgIndex = *inputIndex;
-          auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
           rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
           if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
             auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
@@ -206,18 +199,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
             mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
           }
-          rewriter.startOpModification(spatCompute.getOperation());
-          BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
-          spatCompute.getInputsMutable().erase(BBArgIndex);
-          spatCompute.getBody().front().eraseArgument(BBArgIndex);
-          rewriter.finalizeOpModification(spatCompute.getOperation());
+          replaceAndEraseDirectComputeLikeInput(
+              rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
         }
         else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
-          auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
+          auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
           if (!inputIndex)
             return failure();
           auto BBArgIndex = *inputIndex;
-          auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
           rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
           if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
             auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
@@ -226,11 +215,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
             mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
           }
-          rewriter.startOpModification(spatComputeBatch.getOperation());
-          BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
-          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
-          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
-          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
+          replaceAndEraseDirectComputeLikeInput(rewriter,
+              spatComputeBatch.getOperation(),
+              BBArgIndex,
+              mapSpatComputeToConst[spatComputeBatch.getOperation()]);
         }
         else {
           {
@@ -272,34 +260,26 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
         auto constUsers = constUses.getOwner();
         if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
-          auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
+          auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
           if (!inputIndex)
             return failure();
           auto BBArgIndex = *inputIndex;
-          auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
           rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
           auto newConst = rewriter.clone(*constantOp);
-          rewriter.startOpModification(spatCompute.getOperation());
-          BBArgValue.replaceAllUsesWith(newConst->getResult(0));
-          spatCompute.getInputsMutable().erase(BBArgIndex);
-          spatCompute.getBody().front().eraseArgument(BBArgIndex);
-          rewriter.finalizeOpModification(spatCompute.getOperation());
+          replaceAndEraseDirectComputeLikeInput(
+              rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
         }
         else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
-          auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
+          auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
           if (!inputIndex)
             return failure();
           auto BBArgIndex = *inputIndex;
-          auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
           rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
           auto newConst = rewriter.clone(*constantOp);
-          rewriter.startOpModification(spatComputeBatch.getOperation());
-          BBArgValue.replaceAllUsesWith(newConst->getResult(0));
-          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
-          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
-          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
+          replaceAndEraseDirectComputeLikeInput(
+              rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
         }
         else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
           if (!mapSpatComputeToConst.contains(parent)) {
@@ -321,11 +301,13 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
         }
       }
     }
-    rewriter.eraseOp(constantOp);
+    if (constantOp->use_empty())
+      rewriter.eraseOp(constantOp);
     return success();
   }
 };

+// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
 struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -352,52 +334,36 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
       mlir::MemRefType memRefType =
           mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
-      std::string argName = "arg_" + std::to_string(index);
-      memref::GlobalOp::create(rewriter,
-          loc,
-          rewriter.getStringAttr(argName),
-          rewriter.getStringAttr("private"),
-          TypeAttr::get(memRefType),
-          {},
-          {},
-          {});
+      std::string baseName = ("arg_" + Twine(index)).str();
+      auto globalOp = createPrivateMemrefGlobalWithUniqueName(
+          rewriter, loc, funcOp->getParentOfType<ModuleOp>(), baseName, memRefType);
+      std::string argName = globalOp.getSymName().str();

       for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
         auto argUser = argUses.getOwner();
         if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
-          auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
+          auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
           if (!inputIndex)
             return failure();
           auto BBArgIndex = *inputIndex;
-          auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
           rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
           auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
           auto toTensor = bufferization::ToTensorOp::create(
               rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
-          rewriter.startOpModification(spatCompute.getOperation());
-          BBArgValue.replaceAllUsesWith(toTensor);
-          spatCompute.getInputsMutable().erase(BBArgIndex);
-          spatCompute.getBody().front().eraseArgument(BBArgIndex);
-          rewriter.finalizeOpModification(spatCompute.getOperation());
+          replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
         }
         else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
-          auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
+          auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
           if (!inputIndex)
             return failure();
           auto BBArgIndex = *inputIndex;
-          auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
           rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
           auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
           auto toTensor = bufferization::ToTensorOp::create(
               rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
-          rewriter.startOpModification(spatComputeBatch.getOperation());
-          BBArgValue.replaceAllUsesWith(toTensor);
-          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
-          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
-          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
+          replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor);
         }
         else {
           rewriter.setInsertionPoint(argUser);
@@ -416,7 +382,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
}; };
} // namespace } // namespace
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) { void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>( patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext()); patterns.getContext());
} }
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
}
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
}
@@ -0,0 +1,20 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
bool hasFailure = false;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
hasFailure = true;
});
return success(!hasFailure);
}
} // namespace onnx_mlir
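
For context, a boundary check like this is usually run as its own module pass right after the lowering pipeline. A minimal sketch of such a wrapper, not part of this commit and with a hypothetical pass name:

#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
namespace {
// Hypothetical pass that fails the pipeline if any `spat` op survived lowering.
struct VerifySpatialBoundaryPass
    : mlir::PassWrapper<VerifySpatialBoundaryPass, mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    if (mlir::failed(onnx_mlir::verifySpatialToPimBoundary(getOperation())))
      signalPassFailure();
  }
};
} // namespace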
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
@@ -0,0 +1,587 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
struct ReturnUseInfo {
size_t returnIndex;
SmallVector<Operation*> helperChain;
};
struct ConcatReturnUseInfo {
size_t returnIndex;
SmallVector<int64_t> sliceOffsets;
SmallVector<int64_t> concatShape;
SmallVector<Operation*> concatChain;
SmallVector<Operation*> helperChain;
};
static bool isReturnHelperChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void markOpToRemove(ReturnPathState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
name = (baseName + "_" + Twine(suffix++)).str();
return name;
}
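
The suffix loop above probes the symbol table until a free name is found. A self-contained illustration of the same scheme, with std::set standing in for the symbol table:

#include <set>
#include <string>
// The first collision on "output_0" yields "output_0_0", the next "output_0_1", and so on.
static std::string makeUnique(const std::set<std::string>& taken, const std::string& base) {
  std::string name = base;
  unsigned suffix = 0;
  while (taken.count(name))
    name = base + "_" + std::to_string(suffix++);
  return name;
}
// makeUnique({"output_0", "output_0_0"}, "output_0") == "output_0_1"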
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
int64_t flatIndex = 0;
for (size_t i = 0; i < shape.size(); ++i) {
flatIndex *= shape[i];
flatIndex += indices[i];
}
return flatIndex;
}
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
SmallVector<int64_t> indices(shape.size(), 0);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
indices[dim] = flatIndex % shape[dim];
flatIndex /= shape[dim];
}
return indices;
}
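
Worked example of the two helpers above: with row-major shape [2, 3, 4], indices [1, 2, 3] flatten to ((1 * 3 + 2) * 4) + 3 = 23, and expanding 23 recovers [1, 2, 3]. A plain-C++ round-trip check using std::vector instead of the LLVM containers:

#include <cassert>
#include <vector>
int main() {
  std::vector<long> shape = {2, 3, 4};
  std::vector<long> indices = {1, 2, 3};
  // Flatten, exactly as computeFlatElementIndex does.
  long flat = 0;
  for (size_t i = 0; i < shape.size(); ++i)
    flat = flat * shape[i] + indices[i];
  assert(flat == 23);
  // Expand back, exactly as expandFlatElementIndex does.
  std::vector<long> back(shape.size(), 0);
  for (long dim = (long)shape.size() - 1; dim >= 0; --dim) {
    back[dim] = flat % shape[dim];
    flat /= shape[dim];
  }
  assert(back == indices);
}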
// Matches a single-input, single-result helper compute whose body is a chain
// of shape-only ops feeding the yield; returns that chain in source order.
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin()))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isReturnHelperChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
// Follows a single-use chain of shape-only helper ops from `value` to a
// func.return, recording the chain and the return operand index.
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
auto uses = value.getUses();
if (rangeLength(uses) != 1)
return std::nullopt;
SmallVector<Operation*> helperChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isReturnHelperChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentValue = currentUser->getResult(0);
currentUser = currentUses.begin()->getOwner();
}
if (!isa<func::ReturnOp>(currentUser))
return std::nullopt;
return ReturnUseInfo {
currentValue.getUses().begin()->getOperandNumber(),
std::move(helperChain),
};
}
// Like analyzeReturnUse, but lets the value pass through a chain of concats,
// tracking its offset within the growing result, before reaching the return.
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation* op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getResult();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getOutput();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getOutput();
return {};
};
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getDim();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getAxis();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getAxis();
return std::nullopt;
};
auto getConcatOperands = [](Operation* op) -> OperandRange {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getOperands();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getInputs();
return cast<pim::PimConcatOp>(op).getInputs();
};
auto uses = value.getUses();
if (rangeLength(uses) != 1
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
return std::nullopt;
auto valueType = dyn_cast<ShapedType>(value.getType());
if (!valueType || !valueType.hasStaticShape())
return std::nullopt;
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
SmallVector<Operation*> concatChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
concatChain.push_back(currentUser);
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
int64_t axis = *getConcatAxis(currentUser);
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
Value concatResult = getConcatResult(currentUser);
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
if (!concatType || !concatType.hasStaticShape())
return std::nullopt;
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
currentValue = concatResult;
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt;
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
return std::nullopt;
currentValue = helperCompute.getResult(0);
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
while (isReturnHelperChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentValue = currentUser->getResult(0);
currentUser = currentUses.begin()->getOwner();
}
if (!isa<func::ReturnOp>(currentUser))
return std::nullopt;
return ConcatReturnUseInfo {
currentValue.getUses().begin()->getOperandNumber(),
std::move(sliceOffsets),
std::move(concatShape),
std::move(concatChain),
std::move(helperChain),
};
}
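
The offset bookkeeping above sums the concat-axis extents of every operand that precedes the traced value. A standalone illustration: concatenating tensors of shapes 2x4, 3x4, and 5x4 along axis 0 places the second operand at row offset 2 and the third at row offset 5.

#include <cassert>
#include <vector>
// Offset of operand `operandIndex` along `axis`, as accumulated in the loop above.
long concatOffset(const std::vector<std::vector<long>>& shapes, size_t operandIndex, size_t axis) {
  long offset = 0;
  for (size_t i = 0; i < operandIndex; ++i)
    offset += shapes[i][axis];
  return offset;
}
int main() {
  std::vector<std::vector<long>> shapes = {{2, 4}, {3, 4}, {5, 4}};
  assert(concatOffset(shapes, 1, 0) == 2);
  assert(concatOffset(shapes, 2, 0) == 5);
}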
// Pushes an element index through the helper chain (slices, transposes,
// reshapes), failing on anything that cannot be modeled statically.
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
ArrayRef<int64_t> sourceShape,
ArrayRef<Operation*> helperChain,
SmallVectorImpl<int64_t>& mappedIndices) {
SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
currentIndices = expandFlatElementIndex(flatIndex, currentShape);
return success();
};
for (Operation* op : helperChain) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
return failure();
SmallVector<int64_t> nextIndices;
nextIndices.reserve(currentIndices.size());
for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
extractSliceOp.getStaticOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
if (stride != 1 || index < offset || index >= offset + size)
return failure();
nextIndices.push_back(index - offset);
}
auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
currentIndices = std::move(nextIndices);
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
currentIndices = std::move(nextIndices);
currentShape = std::move(nextShape);
continue;
}
if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
currentIndices = std::move(nextIndices);
currentShape = std::move(nextShape);
continue;
}
if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
if (failed(reshapeToResultShape(op)))
return failure();
continue;
}
return failure();
}
mappedIndices.assign(currentIndices.begin(), currentIndices.end());
return success();
}
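
In the transpose cases above, result dimension d is fed by source dimension perm[d]. A standalone check: permutation [1, 0] maps element [2, 5] of a 3x8 tensor to element [5, 2] of an 8x3 tensor.

#include <cassert>
#include <vector>
int main() {
  std::vector<long> perm = {1, 0};     // swap the two dimensions
  std::vector<long> indices = {2, 5};  // element [2, 5] of a 3x8 tensor
  std::vector<long> shape = {3, 8};
  std::vector<long> nextIndices(2), nextShape(2);
  for (size_t d = 0; d < perm.size(); ++d) {
    nextIndices[d] = indices[perm[d]];  // same remapping as the loop above
    nextShape[d] = shape[perm[d]];
  }
  assert((nextIndices == std::vector<long>{5, 2}));
  assert((nextShape == std::vector<long>{8, 3}));
}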
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static void
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
IRMapping mapping;
mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
clonedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static Value emitHostCopy(IRRewriter& rewriter,
Location loc,
Value outputTensor,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
return PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
} // namespace
void addReturnOutputBuffers(func::ReturnOp returnOp,
IRRewriter& rewriter,
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Value currentReturnValue = returnValue;
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
}
else {
auto outRankedTensorType = llvm::dyn_cast<RankedTensorType>(currentReturnValue.getType());
auto memRefType = MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputBaseName = ("output_" + Twine(index)).str();
std::string outputName = makeUniqueSymbolName(returnOp->getParentOfType<ModuleOp>(), outputBaseName);
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
outputTensors.push_back([memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
}
}
}
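
Each entry in outputTensors is a deferred builder: the memref.global is declared eagerly, but the matching get_global/to_tensor pair is emitted later, at whatever insertion point the return-path lowering picks. A standalone sketch of the capture-by-value idiom (names here are illustrative only):

#include <functional>
#include <string>
#include <vector>
int main() {
  std::vector<std::function<std::string()>> factories;
  std::string outputName = "output_0";  // chosen eagerly, like the GlobalOp symbol
  // The callable owns copies of everything it needs, so it stays valid and
  // repeatable long after the enclosing scope has ended.
  factories.push_back([outputName] { return "get_global " + outputName; });
  std::string materialized = factories[0]();  // built lazily: "get_global output_0"
  (void)materialized;
}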
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType());
if (auto returnUse = analyzeReturnUse(result)) {
Value storedValue = yieldValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op);
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
yieldValue,
static_cast<int32_t>(flatOffset * elementSize),
0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
// Fall back to element-by-element copies: map each element through the
// helper chain and emit one device-to-host copy per element.
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
SmallVector<OpFoldResult> extractOffsets;
SmallVector<OpFoldResult> extractSizes;
SmallVector<OpFoldResult> extractStrides;
extractOffsets.reserve(storedType.getRank());
extractSizes.reserve(storedType.getRank());
extractStrides.reserve(storedType.getRank());
for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
extractOffsets.push_back(rewriter.getIndexAttr(idx));
extractSizes.push_back(rewriter.getIndexAttr(1));
extractStrides.push_back(rewriter.getIndexAttr(1));
}
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
outputTensor = emitHostCopy(rewriter,
loc,
outputTensor,
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
}
return ReturnPathLoweringResult::Handled;
}
return ReturnPathLoweringResult::NotReturnPath;
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
// Recursively marks ops that exist only to feed the return; the lambda takes
// itself as the second argument so it can recurse.
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
return;
bool isExclusivelyOwnedByReturnChain = op->use_empty();
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|| isReturnHelperChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
if (isReturnHelperChainOp(op)) {
Value source = op->getOperand(0);
markOpToRemove(state, op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(state, computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
}
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}
}
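
The self-referential lambda used by markOwnedReturnChain is a standard trick: a lambda cannot name itself, so it receives itself as an extra argument. Isolated in a minimal, runnable form:

#include <cassert>
int main() {
  // The second parameter is the lambda itself, enabling recursion.
  auto factorial = [](long n, auto&& self) -> long {
    return n <= 1 ? 1 : n * self(n - 1, self);
  };
  assert(factorial(5, factorial) == 120);
}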
} // namespace onnx_mlir
@@ -0,0 +1,37 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
struct ReturnPathState {
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
mlir::IRRewriter& rewriter,
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir
File diff suppressed because it is too large
@@ -0,0 +1,113 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
if (op.getDim() != 0)
return failure();
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
if (!packed)
return failure();
rewriter.replaceOp(op, packed);
return success();
}
};
} // namespace
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
packedShape[0] *= count;
return RankedTensorType::get(packedShape, elementType.getElementType());
}
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
if (values.empty())
return {};
if (values.size() == 1)
return values.front();
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
if (!firstSliceOp)
return {};
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|| firstType.getRank() == 0)
return {};
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
return {};
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
int64_t rowsPerValue = firstSizes[0];
if (ShapedType::isDynamic(rowsPerValue))
return {};
for (size_t index = 1; index < values.size(); ++index) {
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|| !hasStaticValues(sliceOp.getStaticStrides()))
return {};
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
return {};
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
return {};
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
return {};
}
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(firstType.getRank());
sizes.reserve(firstType.getRank());
strides.reserve(firstType.getRank());
offsets.push_back(builder.getIndexAttr(firstOffsets[0]));
sizes.push_back(builder.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
strides.push_back(builder.getIndexAttr(firstStrides[0]));
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
offsets.push_back(builder.getIndexAttr(firstOffsets[dim]));
sizes.push_back(builder.getIndexAttr(firstSizes[dim]));
strides.push_back(builder.getIndexAttr(firstStrides[dim]));
}
bool coversWholeSource = packedType == sourceType;
for (int64_t dim = 0; coversWholeSource && dim < sourceType.getRank(); ++dim)
coversWholeSource = firstOffsets[dim] == 0 && firstStrides[dim] == 1;
if (coversWholeSource)
return firstSliceOp.getSource();
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
}
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
}
} // namespace onnx_mlir
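
The fold above only fires when every operand is an extract_slice of the same source taken at consecutive row offsets: slice k must start exactly rowsPerValue rows after slice k-1, with all other offsets unchanged. A standalone replay of that adjacency test:

#include <cassert>
#include <vector>
// True when slice k starts at firstRow + k * rowsPerValue along dimension 0.
bool areContiguousRowSlices(const std::vector<long>& rowOffsets, long firstRow, long rowsPerValue) {
  for (size_t k = 0; k < rowOffsets.size(); ++k)
    if (rowOffsets[k] != firstRow + static_cast<long>(k) * rowsPerValue)
      return false;
  return true;
}
int main() {
  // Three 4-row slices at rows 0, 4, 8 pack into one 12-row slice.
  assert(areContiguousRowSlices({0, 4, 8}, 0, 4));
  // A gap (rows 0, 4, 12) breaks the fold.
  assert(!areContiguousRowSlices({0, 4, 12}, 0, 4));
}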
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -113,6 +113,18 @@ def PimSendBatchOp : PimOp<"send_batch", []> {
let hasCustomAssemblyFormat = 1;
}
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
@@ -181,6 +193,28 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let hasCustomAssemblyFormat = 1;
}
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
@@ -174,6 +174,33 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -275,6 +302,43 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result
return success();
}
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis() << " ";
printCompressedValueSequence(printer, getInputs());
@@ -46,12 +46,47 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe
return success();
}
static LogicalResult
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
return op->emitError() << kind << " must be nested inside pim.core_batch";
int32_t laneCount = coreBatchOp.getLaneCount();
if (laneCount <= 0)
return op->emitError() << kind << " requires a positive parent laneCount";
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
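
As a worked example of the checks above: laneCount = 4 with 8 core ids gives chunkCount = 2 per lane, so a 16-element f32 tensor (64 bytes) splits into two 32-byte chunks, while a 9-byte tensor would be rejected. A standalone replay of the arithmetic:

#include <cassert>
// chunkCount = coreIds / laneCount; the tensor's byte size must divide evenly.
bool chunksDivideEvenly(long coreIdCount, long laneCount, long numElements, long elementBits) {
  if (laneCount <= 0 || coreIdCount % laneCount != 0 || elementBits % 8 != 0)
    return false;
  long chunkCount = coreIdCount / laneCount;
  long totalBytes = numElements * elementBits / 8;
  return totalBytes % chunkCount == 0;
}
int main() {
  assert(chunksDivideEvenly(8, 4, 16, 32));  // 64 bytes into 2 chunks of 32
  assert(!chunksDivideEvenly(8, 4, 9, 8));   // 9 bytes cannot split into 2
}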
} // namespace
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
LogicalResult PimSendTensorBatchOp::verify() {
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
@@ -60,6 +95,15 @@ LogicalResult PimReceiveTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
LogicalResult PimReceiveTensorBatchOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorBatchCommunication(
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
}
LogicalResult PimConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
@@ -0,0 +1,40 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir::pim {
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
FailureOr<Value>
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
if (isa<BufferLikeType>(value.getType()))
return value;
return getBuffer(rewriter, value, options, state);
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,15 @@
#pragma once
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir::pim {
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
mlir::Value value,
const mlir::bufferization::BufferizationOptions& options,
mlir::bufferization::BufferizationState& state);
} // namespace onnx_mlir::pim
@@ -4,6 +4,8 @@ add_public_tablegen_target(PimBufferizationIncGen)
add_pim_library(OMPimBufferization
PimBufferizationPass.cpp
BufferizationUtils.hpp
BufferizationUtils.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp
Common.hpp
@@ -7,6 +7,7 @@
#include "OpBufferizationInterfaces.hpp" #include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
using namespace mlir; using namespace mlir;
using namespace bufferization; using namespace bufferization;
@@ -14,33 +15,6 @@ using namespace bufferization;
namespace onnx_mlir {
namespace pim {
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
static FailureOr<Value>
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
if (isa<BufferLikeType>(value.getType()))
return value;
return getBuffer(rewriter, value, options, state);
}
struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op,
@@ -201,6 +175,27 @@ struct ReceiveTensorOpInterface
}
};
struct ReceiveTensorBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -308,6 +303,31 @@ struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOp
}
};
struct SendTensorBatchOpInterface
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -623,9 +643,11 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
@@ -1,8 +1,8 @@
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
using namespace mlir;
namespace onnx_mlir::spatial {
@@ -102,23 +102,6 @@ def SpatConcatOp : SpatOp<"concat", []> {
let hasCustomAssemblyFormat = 1;
}
def SpatMapOp : SpatOp<"map", [SingleBlock]> {
let summary = "Apply the same lane-local region to many independent tensors";
let arguments = (ins
Variadic<SpatTensor>:$inputs
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
@@ -156,22 +139,25 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
}];
}
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
let summary = "Send multiple tensors through logical channels";
let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
}
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
let summary = "Receive multiple tensors from logical channels";
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
@@ -180,11 +166,14 @@ def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
);
let results = (outs
Variadic<SpatTensor>:$outputs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
@@ -201,18 +190,21 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> {
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
let summary = "Send multiple per-lane tensors through logical channels in a batch body";
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
@@ -232,8 +224,8 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
let summary = "Receive multiple per-lane tensors through logical channels in a batch body";
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
@@ -242,11 +234,14 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
);
let results = (outs
Variadic<SpatTensor>:$outputs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
}
//===----------------------------------------------------------------------===//
@@ -7,8 +7,8 @@
#include <string>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -129,9 +129,8 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
return failure();
if (parseCompressedOperandSequence(parser, inputs)) {
if (parseCompressedOperandSequence(parser, inputs))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
@@ -151,46 +150,6 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success();
}
void SpatMapOp::print(OpAsmPrinter& printer) {
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInputs().front().getType());
printer << " -> ";
printer.printType(getOutputs().front().getType());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
Type inputType;
Type outputType;
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
if (inputs.empty())
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
SmallVector<Type> inputTypes(inputs.size(), inputType);
SmallVector<Type> outputTypes(inputs.size(), outputType);
if (regionArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
applyArgumentTypes(inputTypes, regionArgs);
Region* body = result.addRegion();
return parser.parseRegion(*body, regionArgs);
}
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
@@ -357,97 +316,6 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
return parser.parseRegion(*body, regionArgs);
}
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
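The parsers above accept channel metadata either positionally, after the `channels` keyword, or through the attr-dict, and reject inputs that supply both. A minimal stand-alone sketch of that guard, with a hypothetical helper name but the same parser APIs (assumes the surrounding MLIR includes and namespaces):
// Hypothetical helper (not in this patch) factoring out the duplicate-metadata
// guard used by the parsers above.
static ParseResult verifyNoDuplicateChannelMetadata(OpAsmParser& parser,
    OperationState& result,
    bool hasPositionalMetadata) {
  if (!hasPositionalMetadata)
    return success();
  // result.attributes already holds whatever the attr-dict contributed.
  for (StringRef name : {"channelIds", "sourceCoreIds", "targetCoreIds"})
    if (result.attributes.get(name))
      return parser.emitError(parser.getCurrentLocation(),
          "channel metadata cannot be specified both positionally and in attr-dict");
  return success();
}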
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getResultTypes());
}
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> outputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputTypes);
return success();
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -494,55 +362,6 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
return parser.resolveOperand(input, inputType, result.operands);
}
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
@@ -584,47 +403,5 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
return success();
}
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getResultTypes());
}
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> outputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputTypes);
return success();
}
} // namespace spatial
} // namespace onnx_mlir
+61 -70
View File
@@ -105,26 +105,28 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return batchOp.getLaneCount();
}
-static LogicalResult verifyManyChannelSizes(Operation* op,
+static LogicalResult verifyTensorChannelSizes(Operation* op,
+Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
-size_t valueCount) {
+StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
-if (channelIds.size() != valueCount)
-return op->emitError("channel metadata length must match the number of values");
-return success();
-}
-static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
-if (types.empty())
-return op->emitError() << kind << " must carry at least one value";
-Type firstType = types.front();
-for (Type type : types.drop_front())
-if (type != firstType)
-return op->emitError() << kind << " values must all have the same type";
+if (channelIds.empty())
+return op->emitError() << kind << " must carry at least one chunk";
+auto shapedType = dyn_cast<ShapedType>(type);
+if (!shapedType || !shapedType.hasStaticShape())
+return op->emitError() << kind << " requires a static shaped tensor";
+int64_t elementBits = shapedType.getElementTypeBitWidth();
+if (elementBits <= 0 || elementBits % 8 != 0)
+return op->emitError() << kind << " requires byte-sized elements";
+int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
+if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
+return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success();
}
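As a worked instance of the byte-size rule above (illustrative, assumed shape): a static tensor<16x4xi8> carries 64 bytes, which divides evenly over 4 channel ids but not over 5.
#include <cassert>
#include <cstdint>
// Illustrative only: the byte-size arithmetic of verifyTensorChannelSizes for
// a hypothetical tensor<16x4xi8> split across channel ids.
int main() {
  std::int64_t elementBits = 8;                         // i8 is byte-sized
  std::int64_t totalBytes = (16 * 4) * elementBits / 8; // 64 bytes
  assert(totalBytes % 4 == 0); // 4 channel ids -> 16-byte chunks, accepted
  assert(totalBytes % 5 != 0); // 5 channel ids would be rejected
  return 0;
}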
@@ -144,19 +146,33 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
return success();
}
-static LogicalResult verifyManyBatchChannelSizes(Operation* op,
+static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
+Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
-size_t valueCount) {
+StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
-if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
-return op->emitError("channel metadata length must match the number of values times parent laneCount");
+if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
+return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
+auto shapedType = dyn_cast<ShapedType>(type);
+if (!shapedType || !shapedType.hasStaticShape())
+return op->emitError() << kind << " requires a static shaped tensor";
+int64_t elementBits = shapedType.getElementTypeBitWidth();
+if (elementBits <= 0 || elementBits % 8 != 0)
+return op->emitError() << kind << " requires byte-sized elements";
+int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
+int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
+if (totalBytes % chunkCount != 0)
+return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
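Worked numbers for the batch variant (illustrative): with a parent laneCount of 4 and eight channel ids, chunkCount = 8 / 4 = 2 chunks per lane, so a 64-byte tensor is accepted (64 % 2 == 0); three channel ids would already fail, because 3 is not a multiple of the laneCount.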
@@ -323,39 +339,6 @@ LogicalResult SpatConcatOp::verify() {
return success();
}
LogicalResult SpatMapOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
if (getOutputs().size() != getInputs().size())
return emitError("number of outputs must match number of inputs");
Type inputType = getInputs().front().getType();
for (Value input : getInputs().drop_front())
if (input.getType() != inputType)
return emitError("all inputs must have the same type");
Type outputType = getOutputs().front().getType();
for (Value output : getOutputs().drop_front())
if (output.getType() != outputType)
return emitError("all outputs must have the same type");
Block& block = getBody().front();
if (block.getNumArguments() != 1)
return emitError("body must have exactly one block argument");
if (block.getArgument(0).getType() != inputType)
return emitError("body block argument type must match input type");
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("body must terminate with spat.yield");
if (yieldOp.getNumOperands() != 1)
return emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputType)
return emitError("body yield type must match output type");
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
@@ -397,40 +380,48 @@ LogicalResult SpatCompute::verify() {
return success();
}
-LogicalResult SpatChannelSendManyOp::verify() {
-if (failed(verifyManyChannelSizes(
-getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
-return failure();
-return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
+LogicalResult SpatChannelSendTensorOp::verify() {
+return verifyTensorChannelSizes(getOperation(),
+getInput().getType(),
+getChannelIds(),
+getSourceCoreIds(),
+getTargetCoreIds(),
+"channel_send_tensor");
}
-LogicalResult SpatChannelReceiveManyOp::verify() {
-if (failed(verifyManyChannelSizes(
-getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
-return failure();
-return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
+LogicalResult SpatChannelReceiveTensorOp::verify() {
+return verifyTensorChannelSizes(getOperation(),
+getOutput().getType(),
+getChannelIds(),
+getSourceCoreIds(),
+getTargetCoreIds(),
+"channel_receive_tensor");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
-LogicalResult SpatChannelSendManyBatchOp::verify() {
-if (failed(verifyManyBatchChannelSizes(
-getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
-return failure();
-return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
+LogicalResult SpatChannelSendTensorBatchOp::verify() {
+return verifyTensorBatchChannelSizes(getOperation(),
+getInput().getType(),
+getChannelIds(),
+getSourceCoreIds(),
+getTargetCoreIds(),
+"channel_send_tensor_batch");
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
-LogicalResult SpatChannelReceiveManyBatchOp::verify() {
-if (failed(verifyManyBatchChannelSizes(
-getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
-return failure();
-return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
+LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
+return verifyTensorBatchChannelSizes(getOperation(),
+getOutput().getType(),
+getChannelIds(),
+getSourceCoreIds(),
+getTargetCoreIds(),
+"channel_receive_tensor_batch");
}
LogicalResult SpatComputeBatch::verify() {
@@ -10,6 +10,7 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <algorithm> #include <algorithm>
#include <cstdlib>
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <optional> #include <optional>
@@ -31,6 +32,8 @@ namespace {
using SpatCompute = onnx_mlir::spatial::SpatCompute;
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
+bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
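Only the variable's presence is tested, so any value enables the traces (DCP_COARSEN_DEBUG=0 included); DCPAnalysis::run() below samples the flag once into debugCoarsening, so changing the environment mid-run has no effect.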
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
@@ -719,11 +722,12 @@ DCPAnalysisResult DCPAnalysis::run() {
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
size_t iteration = 0;
+bool debugCoarsening = isDcpCoarsenDebugEnabled();
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty()) {
-if (oldNodeCount >= 200)
+if (debugCoarsening && oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
@@ -737,7 +741,7 @@ DCPAnalysisResult DCPAnalysis::run() {
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
-if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
+if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
@@ -755,7 +759,7 @@ DCPAnalysisResult DCPAnalysis::run() {
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
-if (virtualGraph.nodes.size() >= 200)
+if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
@@ -764,7 +768,7 @@ DCPAnalysisResult DCPAnalysis::run() {
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
-if (virtualGraph.nodes.size() >= 200)
+if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
@@ -776,7 +780,7 @@ DCPAnalysisResult DCPAnalysis::run() {
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
-if (virtualGraph.nodes.size() >= 200)
+if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
@@ -786,7 +790,7 @@ DCPAnalysisResult DCPAnalysis::run() {
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
-if (virtualGraph.nodes.size() >= 200)
+if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
@@ -59,11 +59,7 @@ struct DenseMapInfo<ComputeInstance> {
static ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
-static unsigned getHashValue(const ComputeInstance& v) {
-return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
-}
-static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
-return a == b;
-}
+static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
+static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
};
} // namespace llvm } // namespace llvm
@@ -38,9 +38,11 @@ void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += t
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
if (!logProgress)
return;
-llvm::errs() << llvm::formatv(
-"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
-totalTasks, readyCount, maxCpuCount, xbarsCapacity);
+llvm::errs() << llvm::formatv("[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
+totalTasks,
+readyCount,
+maxCpuCount,
+xbarsCapacity);
}
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
@@ -72,8 +74,7 @@ void DcpProgressLogger::printProgress(
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
bool done = completedTasks == totalTasks;
-llvm::errs() << llvm::formatv(
-"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
+llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
completedTasks,
totalTasks,
percent,
@@ -100,9 +101,7 @@ void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
#endif
-void dumpGraphDot(const std::vector<TaskDCP>& nodes,
-const std::vector<std::list<TaskDCP*>>& cpuTasks,
-CPU lastCpu) {
+void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu) {
static int dumpIndex = 0;
std::string outputDir = onnx_mlir::getOutputDir();
if (outputDir.empty())
@@ -9,9 +9,9 @@
#include "Task.hpp" #include "Task.hpp"
#include "Utils.hpp" #include "Utils.hpp"
// Uncomment to enable DCP progress logging and per-phase profiling during // Define DCP_DEBUG_ENABLED locally when debugging DCP progress and per-phase
// development. When disabled the logger methods are no-ops and the helpers // profiling. In normal builds the logger methods are no-ops and helpers compile
// compile away. // away.
#define DCP_DEBUG_ENABLED #define DCP_DEBUG_ENABLED
#ifdef DCP_DEBUG_ENABLED #ifdef DCP_DEBUG_ENABLED
@@ -33,10 +33,11 @@ public:
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
-void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
-size_t xbarsUsed, size_t xbarsAvailable, bool force);
+void
+printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force);
#ifdef DCP_DEBUG_ENABLED
private:
static std::string formatDuration(double seconds);
@@ -51,8 +52,6 @@ private:
#endif
};
-void dumpGraphDot(const std::vector<TaskDCP>& nodes,
-const std::vector<std::list<TaskDCP*>>& cpuTasks,
-CPU lastCpu);
+void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu);
} // namespace dcp_graph } // namespace dcp_graph
@@ -149,14 +149,6 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
return coreIds;
}
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
return {};
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(coreIdAttr.getInt());
@@ -245,312 +237,6 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
}
static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
IRRewriter& rewriter,
SmallVectorImpl<Operation*>& opsToErase,
int64_t& nextChannelId) {
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
for (auto batch : batches) {
if (batch.getInputs().empty() && batch.getResults().empty())
continue;
if (batch.getInputs().size() != static_cast<size_t>(batch.getLaneCount()))
continue;
if (batch.getResults().size() != static_cast<size_t>(batch.getLaneCount()))
continue;
SmallVector<spatial::SpatChannelReceiveOp> inputReceives;
inputReceives.reserve(batch.getInputs().size());
bool allInputsAreReceives = true;
for (Value input : batch.getInputs()) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
if (!receiveOp) {
allInputsAreReceives = false;
break;
}
inputReceives.push_back(receiveOp);
}
SmallVector<spatial::SpatChannelSendOp> resultSends;
resultSends.reserve(batch.getResults().size());
bool allResultsAreSingleSends = true;
for (Value result : batch.getResults()) {
if (!result.hasOneUse()) {
allResultsAreSingleSends = false;
break;
}
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(*result.getUsers().begin());
if (!sendOp) {
allResultsAreSingleSends = false;
break;
}
resultSends.push_back(sendOp);
}
if (!allInputsAreReceives || !allResultsAreSingleSends)
continue;
Block& oldBlock = batch.getBody().front();
if (oldBlock.getNumArguments() != 1)
continue;
SmallVector<Value> newWeights(batch.getWeights().begin(), batch.getWeights().end());
rewriter.setInsertionPointAfter(batch);
auto newBatch = SpatComputeBatch::create(rewriter,
batch.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(batch.getLaneCount()),
ValueRange(newWeights),
ValueRange {});
newBatch.getProperties().setOperandSegmentSizes({static_cast<int>(newWeights.size()), 0});
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
if (!coreIds.empty())
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
auto* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
rewriter.setInsertionPointToStart(newBlock);
struct BatchReceiveEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchReceiveEntry> receiveEntries;
receiveEntries.reserve(inputReceives.size());
for (auto receiveOp : inputReceives)
receiveEntries.push_back({receiveOp.getChannelId(), receiveOp.getSourceCoreId(), receiveOp.getTargetCoreId()});
llvm::stable_sort(receiveEntries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> receiveChannelIds;
SmallVector<int32_t> receiveSourceCoreIds;
SmallVector<int32_t> receiveTargetCoreIds;
receiveChannelIds.reserve(receiveEntries.size());
receiveSourceCoreIds.reserve(receiveEntries.size());
receiveTargetCoreIds.reserve(receiveEntries.size());
for (const BatchReceiveEntry& entry : receiveEntries) {
(void) entry;
receiveChannelIds.push_back(nextChannelId++);
receiveSourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
receiveTargetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
batch.getLoc(),
oldBlock.getArgument(0).getType(),
rewriter.getDenseI64ArrayAttr(receiveChannelIds),
rewriter.getDenseI32ArrayAttr(receiveSourceCoreIds),
rewriter.getDenseI32ArrayAttr(receiveTargetCoreIds));
IRMapping mapper;
mapper.map(oldBlock.getArgument(0), batchReceive.getOutput());
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
if (&op == oldYield)
continue;
rewriter.clone(op, mapper);
}
Value sendInput = mapper.lookup(oldYield.getOperand(0));
struct BatchSendEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchSendEntry> sendEntries;
sendEntries.reserve(resultSends.size());
for (auto sendOp : resultSends)
sendEntries.push_back({sendOp.getChannelId(), sendOp.getSourceCoreId(), sendOp.getTargetCoreId()});
llvm::stable_sort(sendEntries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> sendChannelIds;
SmallVector<int32_t> sendSourceCoreIds;
SmallVector<int32_t> sendTargetCoreIds;
sendChannelIds.reserve(sendEntries.size());
sendSourceCoreIds.reserve(sendEntries.size());
sendTargetCoreIds.reserve(sendEntries.size());
for (const BatchSendEntry& entry : sendEntries) {
(void) entry;
sendChannelIds.push_back(nextChannelId++);
sendSourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
sendTargetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
spatial::SpatChannelSendBatchOp::create(rewriter,
batch.getLoc(),
rewriter.getDenseI64ArrayAttr(sendChannelIds),
rewriter.getDenseI32ArrayAttr(sendSourceCoreIds),
rewriter.getDenseI32ArrayAttr(sendTargetCoreIds),
sendInput);
spatial::SpatYieldOp::create(rewriter, batch.getLoc(), ValueRange {});
for (auto receiveOp : inputReceives)
opsToErase.push_back(receiveOp);
for (auto sendOp : resultSends)
opsToErase.push_back(sendOp);
opsToErase.push_back(batch);
}
}
void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
SmallVector<Operation*> opsToErase;
for (auto compute : computes) {
SmallVector<unsigned> keptInputIndices;
SmallVector<unsigned> keptResultIndices;
SmallVector<spatial::SpatChannelReceiveOp> internalizedReceives(compute.getInputs().size());
SmallVector<SmallVector<spatial::SpatChannelSendOp>> resultSendOps(compute.getNumResults());
bool needsRewrite = false;
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
if (!receiveOp) {
keptInputIndices.push_back(inputIndex);
continue;
}
internalizedReceives[inputIndex] = receiveOp;
opsToErase.push_back(receiveOp);
needsRewrite = true;
}
for (auto [resultIndex, result] : llvm::enumerate(compute.getResults())) {
bool hasNonSendUser = false;
for (Operation* user : result.getUsers()) {
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(user)) {
resultSendOps[resultIndex].push_back(sendOp);
opsToErase.push_back(sendOp);
needsRewrite = true;
continue;
}
hasNonSendUser = true;
}
if (hasNonSendUser || resultSendOps[resultIndex].empty())
keptResultIndices.push_back(resultIndex);
}
if (!needsRewrite)
continue;
SmallVector<Value> newOperands;
SmallVector<Type> newResultTypes;
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newOperands.reserve(compute.getNumOperands());
newResultTypes.reserve(keptResultIndices.size());
newBlockArgTypes.reserve(keptInputIndices.size());
newBlockArgLocs.reserve(keptInputIndices.size());
for (Value weight : compute.getWeights())
newOperands.push_back(weight);
for (unsigned inputIndex : keptInputIndices) {
Value input = compute.getInputs()[inputIndex];
newOperands.push_back(input);
newBlockArgTypes.push_back(input.getType());
newBlockArgLocs.push_back(compute.getLoc());
}
for (unsigned resultIndex : keptResultIndices)
newResultTypes.push_back(compute.getResult(resultIndex).getType());
rewriter.setInsertionPointAfter(compute);
auto newCompute =
SpatCompute::create(rewriter, compute.getLoc(), TypeRange(newResultTypes), ValueRange(newOperands));
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(keptInputIndices.size())});
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, coreIdAttr);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
IRMapping mapper;
for (auto [mappedIndex, inputIndex] : llvm::enumerate(keptInputIndices))
mapper.map(compute.getBody().front().getArgument(inputIndex), newBlock->getArgument(mappedIndex));
rewriter.setInsertionPointToStart(newBlock);
for (auto [inputIndex, receiveOp] : llvm::enumerate(internalizedReceives)) {
if (!receiveOp)
continue;
auto internalReceive = spatial::SpatChannelReceiveOp::create(rewriter,
receiveOp.getLoc(),
receiveOp.getResult().getType(),
receiveOp.getChannelIdAttr(),
receiveOp.getSourceCoreIdAttr(),
receiveOp.getTargetCoreIdAttr());
mapper.map(compute.getBody().front().getArgument(inputIndex), internalReceive.getResult());
}
auto oldYieldOp = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : compute.getBody().front()) {
if (&op == oldYieldOp)
continue;
rewriter.clone(op, mapper);
}
for (auto [resultIndex, sendOps] : llvm::enumerate(resultSendOps)) {
if (sendOps.empty())
continue;
Value yieldedValue = mapper.lookup(oldYieldOp.getOperand(resultIndex));
for (auto sendOp : sendOps)
spatial::SpatChannelSendOp::create(rewriter,
sendOp.getLoc(),
sendOp.getChannelIdAttr(),
sendOp.getSourceCoreIdAttr(),
sendOp.getTargetCoreIdAttr(),
yieldedValue);
}
SmallVector<Value> keptYieldOperands;
keptYieldOperands.reserve(keptResultIndices.size());
for (unsigned resultIndex : keptResultIndices)
keptYieldOperands.push_back(mapper.lookup(oldYieldOp.getOperand(resultIndex)));
spatial::SpatYieldOp::create(rewriter, oldYieldOp.getLoc(), ValueRange(keptYieldOperands));
for (auto [newResultIndex, oldResultIndex] : llvm::enumerate(keptResultIndices))
compute.getResult(oldResultIndex).replaceAllUsesWith(newCompute.getResult(newResultIndex));
opsToErase.push_back(compute);
}
sinkChannelsIntoBatchComputes(funcOp, rewriter, opsToErase, nextChannelId);
SmallVector<Operation*> pendingRemovals(opsToErase.begin(), opsToErase.end());
while (!pendingRemovals.empty()) {
bool erasedAny = false;
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
if (!(*it)->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(*it);
it = pendingRemovals.erase(it);
erasedAny = true;
}
if (erasedAny)
continue;
for (Operation* op : pendingRemovals)
op->emitError("failed to sink channel op into compute");
llvm_unreachable("channel sinking left cyclic top-level dependencies");
}
}
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
@@ -1280,6 +966,7 @@ public:
void runOnOperation() override {
mergeTriviallyConnectedComputes(getOperation());
+if (std::getenv("DCP_MOTIF_PROFILE"))
emitMotifProfile(getOperation());
func::FuncOp func = getOperation();
@@ -1718,17 +1405,12 @@ public:
for (Operation* user : result.getUsers())
remainingUsers.push_back(user);
if (!remainingUsers.empty()) {
-llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n";
-llvm::errs() << "  erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n";
-op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
-llvm::errs() << "\n";
+InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
+<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
for (Operation* user : remainingUsers) {
-llvm::errs() << "  user: " << user->getName()
-<< " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n";
-user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
-llvm::errs() << "\n";
+diagnostic.attachNote(user->getLoc())
+<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
}
-op->emitOpError("still has uses during per-cpu merge cleanup");
signalPassFailure();
return;
}
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
@@ -40,6 +41,176 @@ struct RegularChunk {
Value output;
};
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
packedShape[0] *= count;
return RankedTensorType::get(packedShape, elementType.getElementType());
}
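For instance (assumed shapes), packing 4 chunks of tensor<2x8xf32> yields tensor<8x8xf32>: only the leading dimension scales. The same shape arithmetic in plain C++:
#include <cassert>
#include <cstdint>
#include <vector>
// Shape-only model of getPackedTensorType: dim 0 is multiplied by the chunk
// count, all other dims (and the element type) are unchanged.
static std::vector<std::int64_t> packShape(std::vector<std::int64_t> shape,
    std::int64_t count) {
  shape[0] *= count;
  return shape;
}
int main() {
  assert((packShape({2, 8}, 4) == std::vector<std::int64_t>{8, 8}));
  return 0;
}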
static Value
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(chunkType.getRank());
sizes.reserve(chunkType.getRank());
strides.reserve(chunkType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
}
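For example (assumed shapes): extracting chunk index 2 with a chunk type of tensor<2x8xf32> reads rows [4, 6) of the packed value, since the leading offset is index * dimSize(0) = 4 while every trailing offset is 0 and all strides are 1.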
static Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
if (values.empty())
return {};
if (values.size() == 1)
return values.front();
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
if (!firstSliceOp)
return {};
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|| firstType.getRank() == 0)
return {};
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
return {};
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
int64_t rowsPerValue = firstSizes[0];
if (ShapedType::isDynamic(rowsPerValue))
return {};
for (size_t index = 1; index < values.size(); ++index) {
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|| !hasStaticValues(sliceOp.getStaticStrides()))
return {};
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
return {};
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
return {};
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
return {};
}
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(firstType.getRank());
sizes.reserve(firstType.getRank());
strides.reserve(firstType.getRank());
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
}
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
if (values.empty())
return false;
auto firstResult = dyn_cast<OpResult>(values.front());
if (!firstResult)
return false;
owner = firstResult.getOwner();
startIndex = firstResult.getResultNumber();
for (auto [index, value] : llvm::enumerate(values)) {
auto result = dyn_cast<OpResult>(value);
if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index)
return false;
}
return true;
}
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
if (values.empty())
return {};
if (Value packedSlice = createPackedExtractSliceTensor(values, rewriter, loc))
return packedSlice;
Operation* owner = nullptr;
unsigned startIndex = 0;
if (getContiguousOpResults(values, owner, startIndex))
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
return createPackedExtractRowsSlice(
extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
auto firstType = dyn_cast<RankedTensorType>(values.front().getType());
if (!firstType || !firstType.hasStaticShape() || firstType.getRank() == 0)
return {};
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
return {};
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
}
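Taken together, createPackedTensorForValues tries three strategies in order: coalescing adjacent tensor.extract_slice ops over the same source into one wider slice, turning contiguous results of a spat.extract_rows into a single packed slice of its input, and finally, when every value shares one static ranked type, concatenating along dimension 0 with tensor.concat. It returns a null Value when none apply, which callers treat as "leave the run uncompacted".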
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
&& lhs.resultType == rhs.resultType;
@@ -89,45 +260,97 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
return chunk;
}
-static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
-auto* block = rewriter.createBlock(
-&mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
-rewriter.setInsertionPointToEnd(block);
-IRMapping mapping;
-mapping.map(anchorChunk.input, block->getArgument(0));
-for (Operation* op : anchorChunk.ops) {
-Operation* cloned = rewriter.clone(*op, mapping);
-for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
-mapping.map(oldResult, newResult);
-}
-spatial::SpatYieldOp::create(
-rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
-}
-static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
-assert(!run.empty() && "expected a non-empty regular chunk run");
-const RegularChunk& anchorChunk = run.front();
-SmallVector<Value> inputs;
-SmallVector<Type> outputTypes;
-inputs.reserve(run.size());
-outputTypes.reserve(run.size());
-for (const RegularChunk& chunk : run) {
-inputs.push_back(chunk.input);
-outputTypes.push_back(chunk.output.getType());
-}
-rewriter.setInsertionPoint(anchorChunk.startOp);
-auto mapOp =
-spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
-buildRegularMapBody(mapOp, anchorChunk, rewriter);
-for (auto [index, chunk] : llvm::enumerate(run)) {
-Value output = chunk.output;
-output.replaceAllUsesWith(mapOp.getResult(index));
-}
+static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
+assert(!run.empty() && "expected a non-empty regular chunk run");
+const RegularChunk& anchorChunk = run.front();
+SmallVector<Value> inputs;
+inputs.reserve(run.size());
+for (const RegularChunk& chunk : run)
+inputs.push_back(chunk.input);
+rewriter.setInsertionPoint(anchorChunk.startOp);
+Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
+if (!packedInput)
+return;
+auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
+auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
+auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
+auto packedInit = tensor::EmptyOp::create(
+rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
+auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
+auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
+auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
+auto loop =
+scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
+{
+OpBuilder::InsertionGuard guard(rewriter);
+Block* loopBlock = loop.getBody();
+rewriter.setInsertionPointToStart(loopBlock);
+Value iv = loopBlock->getArgument(0);
+Value acc = loopBlock->getArgument(1);
+Value inputRowOffset = iv;
+if (inputType.getDimSize(0) != 1) {
+auto rowsPerValue =
+arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
+inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
+}
+SmallVector<OpFoldResult> extractOffsets;
+SmallVector<OpFoldResult> extractSizes;
+SmallVector<OpFoldResult> extractStrides;
+extractOffsets.push_back(inputRowOffset);
+extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(0)));
+extractStrides.push_back(rewriter.getIndexAttr(1));
+for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
+extractOffsets.push_back(rewriter.getIndexAttr(0));
+extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
+extractStrides.push_back(rewriter.getIndexAttr(1));
+}
+auto inputSlice = tensor::ExtractSliceOp::create(
+rewriter, anchorChunk.startOp->getLoc(), inputType, packedInput, extractOffsets, extractSizes, extractStrides);
+IRMapping mapping;
+mapping.map(anchorChunk.input, inputSlice.getResult());
+for (Operation* op : anchorChunk.ops) {
+Operation* cloned = rewriter.clone(*op, mapping);
+for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
+mapping.map(oldResult, newResult);
+}
+Value mappedOutput = mapping.lookup(anchorChunk.output);
+Value outputRowOffset = iv;
+if (outputType.getDimSize(0) != 1) {
+auto rowsPerValue =
+arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
+outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
+}
+SmallVector<OpFoldResult> insertOffsets;
+SmallVector<OpFoldResult> insertSizes;
+SmallVector<OpFoldResult> insertStrides;
+insertOffsets.push_back(outputRowOffset);
+insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(0)));
+insertStrides.push_back(rewriter.getIndexAttr(1));
+for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
+insertOffsets.push_back(rewriter.getIndexAttr(0));
+insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
+insertStrides.push_back(rewriter.getIndexAttr(1));
+}
+auto inserted = tensor::InsertSliceOp::create(
+rewriter, anchorChunk.startOp->getLoc(), mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
+scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
+}
+for (auto [index, chunk] : llvm::enumerate(run)) {
+Value replacement = extractPackedChunk(
+loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
+Value output = chunk.output;
+output.replaceAllUsesWith(replacement);
+}
SmallVector<Operation*> opsToErase;
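Net effect of the rewrite above: instead of materializing one spat.map with a result per chunk, a run of N equivalent chunks now becomes a single scf.for over [0, N) that extract_slices chunk i from the packed input, replays the anchor chunk's operations on it, and insert_slices the result into a packed accumulator tensor; each original consumer is then re-pointed at an extractPackedChunk slice of the loop result.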
@@ -178,28 +401,29 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
-SmallVector<Type> outputTypes;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
-outputTypes.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
-outputTypes.push_back(entry.op.getOutput().getType());
}
+auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
+auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
rewriter.setInsertionPoint(run.front());
-auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
+auto compactReceive =
+spatial::SpatChannelReceiveTensorOp::create(rewriter,
run.front().getLoc(),
-TypeRange(outputTypes),
+packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
-entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
+entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
+compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
for (auto op : run)
rewriter.eraseOp(op);
@@ -255,12 +479,14 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
}
rewriter.setInsertionPoint(run.front());
-spatial::SpatChannelSendManyOp::create(rewriter,
+Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
+if (packedInput) {
+spatial::SpatChannelSendTensorOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
-ValueRange(inputs));
+packedInput);
for (auto op : run)
rewriter.eraseOp(op);
@@ -268,6 +494,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
continue;
}
}
+}
++it;
}
@@ -297,25 +524,25 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
-SmallVector<Type> outputTypes;
-outputTypes.reserve(run.size());
for (auto op : run) {
llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
-outputTypes.push_back(op.getOutput().getType());
}
+auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
+auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
rewriter.setInsertionPoint(run.front());
auto compactReceive =
-spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
+spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
run.front().getLoc(),
-TypeRange(outputTypes),
+packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [index, op] : llvm::enumerate(run))
-op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
+op.getOutput().replaceAllUsesWith(extractPackedChunk(
+compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
for (auto op : run)
rewriter.eraseOp(op);
@@ -352,12 +579,14 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
     }
     rewriter.setInsertionPoint(run.front());
-    spatial::SpatChannelSendManyBatchOp::create(rewriter,
+    Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
+    if (packedInput) {
+      spatial::SpatChannelSendTensorBatchOp::create(rewriter,
         run.front().getLoc(),
         rewriter.getDenseI64ArrayAttr(channelIds),
         rewriter.getDenseI32ArrayAttr(sourceCoreIds),
         rewriter.getDenseI32ArrayAttr(targetCoreIds),
-        ValueRange(inputs));
+        packedInput);
     for (auto op : run)
       rewriter.eraseOp(op);
@@ -365,6 +594,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
       continue;
     }
   }
+    }
   ++it;
 }
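
Note on the pack/unpack helpers: getPackedTensorType, createPackedTensorForValues, and extractPackedChunk are used but not defined in these hunks. A minimal sketch of what two of them could look like, assuming the pack simply prepends a run-sized leading dimension; both the layout and the use of tensor.extract_slice here are assumptions, not this repository's verified implementation:

// Assumed helpers (would require "mlir/Dialect/Tensor/IR/Tensor.h").
// Builds tensor<run x rowShape...> from the per-op row type.
static RankedTensorType getPackedTensorType(RankedTensorType rowType, int64_t runSize) {
  SmallVector<int64_t> shape;
  shape.push_back(runSize);                       // leading "run" dimension
  llvm::append_range(shape, rowType.getShape());  // original row shape after it
  return RankedTensorType::get(shape, rowType.getElementType());
}

// Rank-reducing slice that recovers chunk `index` with the original row type.
static Value extractPackedChunk(Value packed, RankedTensorType rowType, unsigned index,
    OpBuilder& builder, Location loc) {
  int64_t rank = cast<RankedTensorType>(packed.getType()).getRank();
  SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
  offsets[0] = builder.getIndexAttr(index);
  SmallVector<OpFoldResult> sizes;
  sizes.push_back(builder.getIndexAttr(1));
  for (int64_t dim : rowType.getShape())
    sizes.push_back(builder.getIndexAttr(dim));
  SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
  return tensor::ExtractSliceOp::create(builder, loc, rowType, packed, offsets, sizes, strides);
}

Whatever the exact layout, the trade is the same: a run of same-typed sends or receives becomes one contiguous channel transfer plus cheap slicing, instead of N separate transfers.
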
@@ -1,9 +1,9 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h" #include "llvm/ADT/Hashing.h"
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -31,7 +31,8 @@ struct DenseSubviewKeyInfo {
   static unsigned getHashValue(const DenseSubviewKey& key) {
     return static_cast<unsigned>(
-        llvm::hash_combine(key.source, llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
+        llvm::hash_combine(key.source,
+            llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
             llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end())));
   }
@@ -98,16 +99,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
       alignment);
 }

-FailureOr<DenseElementsAttr> foldDenseSubview(DenseElementsAttr denseAttr,
-    ArrayRef<int64_t> staticOffsets,
-    ArrayRef<int64_t> resultShape) {
+FailureOr<DenseElementsAttr>
+foldDenseSubview(DenseElementsAttr denseAttr, ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> resultShape) {
   auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
   if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast<int64_t>(staticOffsets.size())
       || sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
     return failure();

   static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo> cache;
-  DenseSubviewKey key {denseAttr, SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
+  DenseSubviewKey key {denseAttr,
+      SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
       SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
   if (auto cached = cache.find(key); cached != cache.end())
     return cached->second;
@@ -152,6 +153,30 @@ FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value)
   return denseAttr;
 }

+FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value source, MemRefType resultType) {
+  auto srcSubview = getStaticSubviewInfo(source);
+  Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(source);
+  auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
+  if (failed(denseAttr))
+    return failure();
+
+  if (succeeded(srcSubview)) {
+    if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
+      return failure();
+    auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
+    if (failed(staticOffsets))
+      return failure();
+    return foldDenseSubview(*denseAttr, *staticOffsets, resultType.getShape());
+  }
+
+  auto resultTensorType = RankedTensorType::get(resultType.getShape(), resultType.getElementType());
+  if (resultTensorType != denseAttr->getType())
+    return failure();
+  return *denseAttr;
+}
+
 FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
   value = stripMemRefViewOps(value);
   auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
@@ -36,6 +36,9 @@ llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAtt
 llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);

+llvm::FailureOr<mlir::DenseElementsAttr>
+foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
+
 llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);

 /// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
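
The declaration added above centralizes the subview-versus-cast handling that FoldConstantHostCopyPattern and FoldConstantMemCpPattern previously each open-coded. A minimal, hypothetical call site (variable names are illustrative; the shape of the call matches how the two patterns below use it):

// Fold whatever constant feeds `copySource` into an attribute shaped like `allocType`.
FailureOr<DenseElementsAttr> folded = foldDenseSourceToType(moduleOp, copySource, allocType);
if (failed(folded))
  return failure();
// *folded now holds the dense payload, e.g. for createFoldedGlobal(moduleOp, loc, allocType, *folded, name).
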
@@ -90,6 +90,7 @@ static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
   return attr;
 }

+// Folds constant linalg fills inside cores into private globals plus device copies.
 struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -249,6 +250,7 @@ static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, M
   return DenseElementsAttr::get(resultTensorType, resultValues);
 }

+// Folds transposes of constant globals so weight-only transposes stay host-side.
 struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -304,11 +306,9 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
     rewriter.setInsertionPoint(transposeOp);
     auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());

-    bool isAlwaysWeight =
-        !transposeOp->getUsers().empty()
-        && llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
-          return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
-        });
+    bool isAlwaysWeight = !transposeOp->getUsers().empty()
+        && llvm::all_of(transposeOp->getUsers(),
+            [](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
     if (isAlwaysWeight) {
       markWeightAlways(newGlobal);
       markWeightAlways(newGetGlobal);
@@ -330,6 +330,7 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
   }
 };

+// Collapses fill-and-copy allocation chains into one folded constant global.
 struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -367,9 +368,8 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
     }

     if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
-          return llvm::all_of(castOp->getUsers(), [](Operation* user) {
-            return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
-          });
+          return llvm::all_of(castOp->getUsers(),
+              [](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
         })) {
       allLiveUsersAreCoreOps = false;
     }
@@ -417,6 +417,7 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
   }
 };

+// Converts host copies from dense globals into direct folded globals.
 struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -431,37 +432,14 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
     if (!allocType || !allocType.hasStaticShape())
       return failure();

-    auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
-    Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
-
     auto moduleOp = copyOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
       return failure();

-    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
-    if (failed(denseAttr))
+    auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
+    if (failed(foldedAttr))
       return failure();

-    DenseElementsAttr foldedAttr;
-    if (succeeded(srcSubview)) {
-      if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
-        return failure();
-      auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
-      if (failed(staticOffsets))
-        return failure();
-      auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
-      if (failed(maybeFoldedAttr))
-        return failure();
-      foldedAttr = *maybeFoldedAttr;
-    }
-    else {
-      auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
-      if (resultTensorType != denseAttr->getType())
-        return failure();
-      foldedAttr = *denseAttr;
-    }
-
     bool allLiveUsersAreCores = true;
     for (Operation* user : allocOp->getUsers()) {
       if (user == copyOp)
@@ -477,7 +455,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
         return failure();
     }

-    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
+    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy");
     if (allLiveUsersAreCores)
       markWeightAlways(newGlobal);
@@ -494,6 +472,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
   }
 };

+// Converts PIM copies from dense globals into direct folded globals before codegen.
 struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -511,37 +490,14 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
     if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
       return failure();

-    auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
-    Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
-
     auto moduleOp = copyOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
       return failure();

-    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
-    if (failed(denseAttr))
+    auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
+    if (failed(foldedAttr))
       return failure();

-    DenseElementsAttr foldedAttr;
-    if (succeeded(srcSubview)) {
-      if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
-        return failure();
-      auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
-      if (failed(staticOffsets))
-        return failure();
-      auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
-      if (failed(maybeFoldedAttr))
-        return failure();
-      foldedAttr = *maybeFoldedAttr;
-    }
-    else {
-      auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
-      if (resultTensorType != denseAttr->getType())
-        return failure();
-      foldedAttr = *denseAttr;
-    }
-
     bool allLiveUsersAreCores = true;
     for (Operation* user : allocOp->getUsers()) {
       if (user == copyOp)
@@ -557,7 +513,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
         return failure();
     }

-    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
+    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_memcp");
     if (allLiveUsersAreCores)
       markWeightAlways(newGlobal);
@@ -577,13 +533,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
 } // namespace

 void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
-  patterns
-      .add<FoldConstantTransposePattern,
+  patterns.add<FoldConstantTransposePattern,
       FoldConstantAllocPattern,
       FoldConstantCoreMapPattern,
       FoldConstantHostCopyPattern,
-          FoldConstantMemCpPattern>(
-          patterns.getContext());
+      FoldConstantMemCpPattern>(patterns.getContext());
 }

 } // namespace onnx_mlir
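
For context, pattern-population entry points like this one are normally driven from a pass with MLIR's greedy rewriter; a sketch under that assumption (the pass itself is not part of this commit, and the driver name assumes the current upstream MLIR API):

// Hypothetical pass body wiring up the folding patterns registered above.
void runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateConstantFoldingConstantPatterns(patterns);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
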
@@ -128,6 +128,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
   return success();
 }

+// Splits core copies through subviews into contiguous copy chunks for codegen.
 struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -162,6 +163,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
   }
 };

+// Splits host-to-device subview loads into contiguous copy chunks.
 struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -193,6 +195,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
   }
 };

+// Splits device-to-host subview stores into contiguous copy chunks.
 struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -224,6 +227,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
   }
 };

+// Folds constant subviews used as core weights into standalone globals.
 struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@@ -9,6 +10,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h" #include "llvm/Support/MathExtras.h"
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -21,6 +24,8 @@ namespace {
 static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
   if (isa<pim::PimMemCopyHostToDevOp>(op))
     return operandIndex == 1;
+  if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
+    return operandIndex == 1;
   if (isa<pim::PimMemCopyDevToHostOp>(op))
     return operandIndex == 0;
   return false;
@@ -33,6 +38,91 @@ static int64_t getValueSizeInBytes(Value value) {
   return type.getNumElements() * type.getElementTypeBitWidth() / 8;
 }

+template <typename CoreOpTy>
+static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) {
+  DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
+  SmallVector<Operation*> ops;
+  coreOp.getBody().front().walk([&](Operation* op) {
+    if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
+      ops.push_back(op);
+  });
+
+  for (Operation* op : ops) {
+    for (OpOperand& operand : op->getOpOperands()) {
+      Value originalValue = operand.get();
+      if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
+        continue;
+
+      auto resolvedAddress = resolveContiguousAddress(originalValue);
+      if (failed(resolvedAddress))
+        continue;
+
+      auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
+      if (!getGlobalOp)
+        continue;
+
+      auto originalType = dyn_cast<MemRefType>(originalValue.getType());
+      if (!originalType || !originalType.hasStaticShape()) {
+        op->emitOpError("host constant materialization requires a static memref operand");
+        hasFailure = true;
+        continue;
+      }
+
+      auto& cachedByOffset = materializedValues[resolvedAddress->base];
+      auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
+      auto cachedValue = cachedByType.find(originalType);
+      if (cachedValue != cachedByType.end()) {
+        operand.set(cachedValue->second);
+        continue;
+      }
+
+      int64_t totalBytes = getValueSizeInBytes(originalValue);
+      if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
+        op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
+        hasFailure = true;
+        continue;
+      }
+
+      auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
+      rewriter.setInsertionPoint(op);
+      Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
+      Value deviceDst = localAlloc;
+      if (contiguousType != originalType)
+        deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
+
+      Value copiedValue;
+      if constexpr (std::is_same_v<CoreOpTy, pim::PimCoreBatchOp>) {
+        copiedValue = pim::PimMemCopyHostToDevBatchOp::create(
+            rewriter,
+            op->getLoc(),
+            originalType,
+            deviceDst,
+            getGlobalOp.getResult(),
+            rewriter.getI32IntegerAttr(0),
+            rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
+            rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
+            .getOutput();
+      }
+      else {
+        copiedValue = pim::PimMemCopyHostToDevOp::create(
+            rewriter,
+            op->getLoc(),
+            originalType,
+            deviceDst,
+            getGlobalOp.getResult(),
+            rewriter.getI32IntegerAttr(0),
+            rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
+            rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
+            .getOutput();
+      }
+      cachedByType[originalType] = copiedValue;
+      operand.set(copiedValue);
+    }
+  }
+}

 struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
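
The if constexpr branch above is what lets one template body serve both pim::PimCoreOp and pim::PimCoreBatchOp without a runtime check: the discarded branch is never instantiated for the other core kind. The same dispatch pattern, distilled into a self-contained toy example (all types and names here are illustrative only):

#include <cstdio>
#include <string>
#include <type_traits>

struct ScalarCore {};
struct BatchCore {};

// One templated body; the copy-op flavor is resolved at compile time.
template <typename CoreT>
std::string hostToDevCopyOpName() {
  if constexpr (std::is_same_v<CoreT, BatchCore>)
    return "host-to-dev copy (batch variant)";
  else
    return "host-to-dev copy (scalar variant)";
}

int main() {
  // Prints a different op name per core kind, with zero runtime dispatch.
  std::puts(hostToDevCopyOpName<ScalarCore>().c_str());
  std::puts(hostToDevCopyOpName<BatchCore>().c_str());
}
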
@@ -50,71 +140,11 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
       if (funcOp.isExternal())
         continue;

-      for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
-        DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
-        for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
-          if (isa<pim::PimHaltOp>(op))
-            continue;
-          for (OpOperand& operand : op.getOpOperands()) {
-            Value originalValue = operand.get();
-            if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
-              continue;
-
-            auto resolvedAddress = resolveContiguousAddress(originalValue);
-            if (failed(resolvedAddress))
-              continue;
-
-            auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
-            if (!getGlobalOp)
-              continue;
-
-            auto originalType = dyn_cast<MemRefType>(originalValue.getType());
-            if (!originalType || !originalType.hasStaticShape()) {
-              op.emitOpError("host constant materialization requires a static memref operand");
-              hasFailure = true;
-              continue;
-            }
-
-            auto& cachedByOffset = materializedValues[resolvedAddress->base];
-            auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
-            auto cachedValue = cachedByType.find(originalType);
-            if (cachedValue != cachedByType.end()) {
-              operand.set(cachedValue->second);
-              continue;
-            }
-
-            int64_t totalBytes = getValueSizeInBytes(originalValue);
-            if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
-              op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
-              hasFailure = true;
-              continue;
-            }
-
-            auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
-            rewriter.setInsertionPoint(&op);
-            Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
-            Value deviceDst = localAlloc;
-            if (contiguousType != originalType)
-              deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
-            auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
-                rewriter,
-                op.getLoc(),
-                originalType,
-                deviceDst,
-                getGlobalOp.getResult(),
-                rewriter.getI32IntegerAttr(0),
-                rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
-                rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
-            cachedByType[originalType] = hostToDevCopy.getResult();
-            operand.set(hostToDevCopy.getResult());
-          }
-        }
-      }
+      for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
+        materializeHostConstantsInCore(coreOp, rewriter, hasFailure);
+      for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
+        materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure);

       SmallVector<Operation*> hostCompactOps;
       for (Operation& op : funcOp.getBody().front())
@@ -122,6 +122,14 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
     ModuleOp moduleOp = getOperation();
     bool hasFailure = false;

+    moduleOp.walk([&](Operation* op) {
+      if (op->getDialect()->getNamespace() != "spat")
+        return;
+      op->emitError("illegal Spatial operation reached PIM codegen verification");
+      hasFailure = true;
+    });
+
     for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
       if (funcOp.isExternal())
         continue;
+3 -5
@@ -33,8 +33,6 @@ def _parse_pim_pass_timings(output_text):
             pass_timings[label] = pass_timings.get(label, 0.0) + duration
             break

-    if not pass_timings:
-        raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
     return pass_timings
@@ -43,7 +41,7 @@ def _format_command(cmd):
 def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
-                        crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
+                        crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None):
     # Define the arguments, with the possibility to set crossbar size and count
     args = [
         network_path,
@@ -51,13 +49,13 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
         output_base,
         "--maccel=PIM",
         "--EmitPimCodegen",
-        # "--use-experimental-conv-impl=true",
         f"--crossbar-size={crossbar_size}",
         f"--crossbar-count={crossbar_count}",
-        "--enable-timing",
     ]
     if core_count is not None:
         args.append(f"--core-count={core_count}")
+    if verbose:
+        args.append("--enable-timing")

     cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args]
     if reporter is not None:
+7 -5
@@ -47,7 +47,9 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=
     return_code = process.wait()
     if return_code != 0:
         error_output = captured_output if not stream_output else recent_output
-        raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
+        exc = subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
+        exc.output_already_streamed = stream_output and bool(captured_output)
+        raise exc
     return bytes(captured_output)
@@ -67,15 +69,15 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
     stream_output = bool(getattr(reporter, "verbose", False))

     if not stream_output:
-        process = subprocess.Popen(
+        completed = subprocess.run(
             cmd,
             cwd=cwd,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT,
         )
-        assert process.stdout is not None
-        output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False)
-        return output.decode("utf-8", errors="replace") if capture_output else None
+        if completed.returncode != 0:
+            raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
+        return completed.stdout.decode("utf-8", errors="replace") if capture_output else None

     try:
         master_fd, slave_fd = pty.openpty()
+4 -1
@@ -27,7 +27,9 @@ def print_validation_error(reporter, rel, exc):
           file=sys.stderr, flush=True)
     if isinstance(exc, subprocess.CalledProcessError):
         print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
-        if exc.output:
+        if getattr(exc, "output_already_streamed", False):
+            print("Failure log already printed above.", file=sys.stderr, flush=True)
+        elif exc.output:
             output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output)
             if output_text:
                 print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True)
@@ -160,6 +162,7 @@ def main():
             Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
         print(f"| {rel.ljust(path_width)} | {status} |")
     print(separator)
-    print_average_pim_pass_timings(
-        pass_timing_sums,
-        pass_timing_counts,
+    if a.verbose:
+        print_average_pim_pass_timings(
+            pass_timing_sums,
+            pass_timing_counts,
+18 -13
@@ -1,6 +1,5 @@
 import json
 import numpy as np
-import subprocess
 import shutil
 import sys
 from dataclasses import dataclass, field
@@ -11,7 +10,6 @@ from raptor import compile_with_raptor
 from gen_network_runner import gen_network_runner
 from subprocess_utils import run_command_with_reporter
-
 STAGE_TITLES = (
     "Compile ONNX",
     "Build Runner",
@@ -48,10 +46,12 @@ class ProgressReporter:
         self.verbose = verbose
         self.columns = shutil.get_terminal_size((100, 20)).columns
         self.suspended = False
+        self.rendered_width = 0

     def _clear(self):
         if self.enabled:
-            sys.stdout.write("\033[2K\r")
+            sys.stdout.write("\r" + (" " * self.rendered_width) + "\r")
+            sys.stdout.flush()

     def _render(self):
         if not self.enabled or self.suspended:
@@ -92,9 +92,12 @@ class ProgressReporter:
         available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
         label = label[:available_label_width]

-        sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
+        plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label
+        rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL
+        padded_width = max(self.rendered_width, len(plain_line))
+        sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line))))
         sys.stdout.flush()
+        self.rendered_width = len(plain_line)

     def log(self, message="", color=None):
         if not self.verbose:
@@ -124,18 +127,19 @@ class ProgressReporter:
         self._render()

     def suspend(self):
-        self.suspended = True
-        self._clear()
-        sys.stdout.flush()
+        if self.enabled:
+            self._clear()
+        self.suspended = True

     def resume(self):
         self.suspended = False
+        self._render()

     def finish(self):
         if self.enabled:
             self.suspended = True
             self._clear()
-            sys.stdout.flush()
+            self.rendered_width = 0


 def run_command(cmd, cwd=None, reporter=None):
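
The reporter changes above swap the ANSI erase-line escape for manual padding: the plain-text width of the previous render is remembered and overwritten with spaces, which also behaves correctly on terminals that ignore \033[2K and lets resume() repaint the line. The same idea in a minimal C++ sketch (all names are illustrative, not part of this repository):

#include <algorithm>
#include <cstdio>
#include <string>

struct ProgressLine {
  std::size_t renderedWidth = 0;

  void render(const std::string& plain) {
    // Pad with spaces so a shorter line fully covers the previous render.
    std::size_t padded = std::max(renderedWidth, plain.size());
    std::printf("\r%s%*s", plain.c_str(), static_cast<int>(padded - plain.size()), "");
    std::fflush(stdout);
    renderedWidth = plain.size();
  }

  void clear() {
    // Overwrite the last render with spaces, then return the cursor home.
    std::printf("\r%*s\r", static_cast<int>(renderedWidth), "");
    std::fflush(stdout);
  }
};

int main() {
  ProgressLine line;
  line.render("model 1/3: compiling");
  line.render("model 2/3: ok");  // shorter line, still fully overwrites
  line.clear();
}
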
@@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor):
 def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
     run_command(
-        ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
+        ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
+         "--",
          "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
         cwd=simulator_dir,
         reporter=reporter,
@@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
     reporter.advance()

     print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
-    gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
+    gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
+                       verbose=False)
     runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
     print_info(reporter, f"Runner built at {runner_path}")
     reporter.advance()
@@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
     print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
     pim_pass_timings = compile_with_raptor(
-        network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
-        crossbar_size, crossbar_count, core_count=core_count,
-        cwd=raptor_dir, reporter=reporter)
+        network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
+        core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
     print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
     reporter.advance()