Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -72,9 +72,8 @@ inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<IntT>& values) {
|
||||
inline ParseResult
|
||||
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
@@ -113,8 +112,8 @@ inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser,
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0) {
|
||||
return parser.emitError(
|
||||
parser.getCurrentLocation(), "range end must be reachable from start using the given step");
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
}
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
@@ -140,9 +139,8 @@ inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser,
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline ParseResult parseCompressedIntegerSequence(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<IntT>& values) {
|
||||
inline ParseResult
|
||||
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||
@@ -166,9 +164,7 @@ inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline void printCompressedIntegerSequence(OpAsmPrinter& printer,
|
||||
ArrayRef<IntT> values,
|
||||
ListDelimiter delimiter) {
|
||||
inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||
struct FlatCompression {
|
||||
enum class Kind {
|
||||
Single,
|
||||
@@ -388,9 +384,7 @@ inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<Type>& types,
|
||||
bool allowEmpty) {
|
||||
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||
Type firstType;
|
||||
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||
if (!firstTypeResult.has_value()) {
|
||||
@@ -422,10 +416,9 @@ inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOperandEntryWithFirst(
|
||||
OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::UnresolvedOperand lastOperand;
|
||||
if (parser.parseOperand(lastOperand))
|
||||
@@ -447,9 +440,8 @@ inline ParseResult parseCompressedOperandEntryWithFirst(
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseOneCompressedOperandEntry(
|
||||
OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
OpAsmParser::UnresolvedOperand firstOperand;
|
||||
if (parser.parseOperand(firstOperand))
|
||||
return failure();
|
||||
@@ -474,9 +466,8 @@ inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||
}
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOperandSequence(
|
||||
OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
@@ -485,9 +476,7 @@ inline ParseResult parseCompressedOperandSequence(
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedTypeList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<Type>& types) {
|
||||
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
@@ -522,10 +511,7 @@ inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void printValueTupleRun(OpAsmPrinter& printer,
|
||||
ValueRange values,
|
||||
size_t tupleSize,
|
||||
ListDelimiter delimiter) {
|
||||
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
@@ -538,10 +524,7 @@ inline void printValueTupleRun(OpAsmPrinter& printer,
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline void printTypeTupleRun(OpAsmPrinter& printer,
|
||||
TypeRange types,
|
||||
size_t tupleSize,
|
||||
ListDelimiter delimiter) {
|
||||
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
@@ -554,10 +537,9 @@ inline void printTypeTupleRun(OpAsmPrinter& printer,
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOrTupleOperandList(
|
||||
OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
@@ -604,9 +586,8 @@ inline ParseResult parseCompressedOrTupleOperandList(
|
||||
}
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<Type>& types) {
|
||||
inline ParseResult
|
||||
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
|
||||
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
|
||||
|
||||
add_pim_library(OMPimCompilerUtils
|
||||
PimCompilerUtils.cpp
|
||||
PimArtifactWriter.cpp
|
||||
PimBatchEmission.cpp
|
||||
PimCodeGen.cpp
|
||||
PimWeightEmitter.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
|
||||
std::error_code errorCode;
|
||||
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
|
||||
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
|
||||
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
||||
hostFileStream.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes
|
||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||
|
||||
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp)
|
||||
return;
|
||||
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||
return;
|
||||
auto initialValue = globalOp.getInitialValue();
|
||||
if (!initialValue)
|
||||
return;
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||
if (!denseAttr)
|
||||
return;
|
||||
|
||||
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
|
||||
ArrayRef<char> rawData = denseAttr.getRawData();
|
||||
char* dst = memoryBuffer.data() + memEntry.address;
|
||||
|
||||
if (denseAttr.isSplat()) {
|
||||
size_t elementSize = rawData.size();
|
||||
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
|
||||
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
|
||||
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
|
||||
}
|
||||
else {
|
||||
assert(rawData.size() == memEntry.size && "Data size mismatch");
|
||||
std::memcpy(dst, rawData.data(), rawData.size());
|
||||
}
|
||||
});
|
||||
|
||||
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
|
||||
memoryFileStream.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
size_t maxCoreId,
|
||||
json::Object xbarsPerArrayGroup,
|
||||
StringRef outputDirPath) {
|
||||
json::Object configJson;
|
||||
|
||||
configJson["core_cnt"] = maxCoreId + 1;
|
||||
configJson["adc_count"] = 16;
|
||||
configJson["cell_precision"] = 2;
|
||||
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||
|
||||
json::Array inputsAddresses;
|
||||
for (BlockArgument input : funcOp.getArguments())
|
||||
inputsAddresses.push_back(memory.getValueAddress(input));
|
||||
configJson["inputs_addresses"] = std::move(inputsAddresses);
|
||||
|
||||
json::Array outputsAddresses;
|
||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||
for (mlir::Value output : returnOp.getOperands())
|
||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||
|
||||
auto configPath = (outputDirPath + "/config.json").str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream jsonOS(configPath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening config file: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
jsonOS << json::Value(std::move(configJson)) << '\n';
|
||||
jsonOS.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
class PimAcceleratorMemory;
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
|
||||
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||
mlir::func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
llvm::StringRef outputDirPath);
|
||||
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
size_t maxCoreId,
|
||||
llvm::json::Object xbarsPerArrayGroup,
|
||||
llvm::StringRef outputDirPath);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,126 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||
OpBuilder builder(scratchModule->getContext());
|
||||
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||
SmallVector<Value> laneWeights;
|
||||
laneWeights.reserve(weightsPerLane);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||
|
||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||
auto scalarCore = pim::PimCoreOp::create(
|
||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
IRMapping mapper;
|
||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op)) {
|
||||
pim::PimHaltOp::create(builder, op.getLoc());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(builder,
|
||||
sendBatchOp.getLoc(),
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
builder,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
builder,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
|
||||
if (!hostSource)
|
||||
hostSource = memcpBatchOp.getHostSource();
|
||||
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
hostSource,
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
return callback(scalarCore);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+24
-376
@@ -5,12 +5,10 @@
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
@@ -21,7 +19,6 @@
|
||||
#include <absl/types/compare.h>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
@@ -29,8 +26,11 @@
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
@@ -42,79 +42,6 @@ static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct DenseWeightView {
|
||||
DenseElementsAttr denseAttr;
|
||||
SmallVector<int64_t> shape;
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
||||
strides[index] = strides[index + 1] * shape[index + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static bool allStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
static FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = current.getDefiningOp();
|
||||
if (!defOp)
|
||||
return failure();
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStridesForShape(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
}
|
||||
|
||||
return view;
|
||||
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
@@ -723,80 +650,6 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
||||
OpBuilder builder(coreBatchOp);
|
||||
builder.setInsertionPointAfter(coreBatchOp);
|
||||
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||
SmallVector<mlir::Value> laneWeights;
|
||||
laneWeights.reserve(weightsPerLane);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||
|
||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||
auto scalarCore = pim::PimCoreOp::create(
|
||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
IRMapping mapper;
|
||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op)) {
|
||||
pim::PimHaltOp::create(builder, op.getLoc());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(builder,
|
||||
sendBatchOp.getLoc(),
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
|
||||
if (!hostSource)
|
||||
hostSource = memcpBatchOp.getHostSource();
|
||||
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
hostSource,
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
return scalarCore;
|
||||
}
|
||||
|
||||
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
func::FuncOp funcOp,
|
||||
pim::PimCoreOp coreOp,
|
||||
@@ -822,56 +675,6 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
});
|
||||
}
|
||||
|
||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||
static OnnxMlirCompilerErrorCodes
|
||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||
|
||||
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp)
|
||||
return;
|
||||
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||
return;
|
||||
auto initialValue = globalOp.getInitialValue();
|
||||
if (!initialValue)
|
||||
return;
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||
if (!denseAttr)
|
||||
return;
|
||||
|
||||
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
|
||||
ArrayRef<char> rawData = denseAttr.getRawData();
|
||||
char* dst = memoryBuffer.data() + memEntry.address;
|
||||
|
||||
if (denseAttr.isSplat()) {
|
||||
size_t elementSize = rawData.size();
|
||||
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
|
||||
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
|
||||
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
|
||||
}
|
||||
else {
|
||||
assert(rawData.size() == memEntry.size && "Data size mismatch");
|
||||
std::memcpy(dst, rawData.data(), rawData.size());
|
||||
}
|
||||
});
|
||||
|
||||
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
|
||||
memoryFileStream.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
/// Dispatch all operations in a core region to the appropriate code generator.
|
||||
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
|
||||
/// fully resolved before the JSON instructions are emitted.
|
||||
@@ -926,7 +729,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
op.dump();
|
||||
return failure();
|
||||
}
|
||||
processedOperations++;
|
||||
@@ -935,154 +737,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||
}
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||
assert(!error && "Error creating weights directory");
|
||||
size_t indexFileName = 0;
|
||||
|
||||
int64_t xbarSize = crossbarSize.getValue();
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
SmallVector<pim::PimCoreOp> scalarCores;
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
scalarCores.push_back(coreOp);
|
||||
}
|
||||
else {
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||
scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
|
||||
}
|
||||
|
||||
for (pim::PimCoreOp coreOp : scalarCores) {
|
||||
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
|
||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||
if (index >= coreOp.getWeights().size()) {
|
||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
mlir::Value weight = coreOp.getWeights()[index];
|
||||
|
||||
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||
if (failed(weightView)) {
|
||||
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||
}
|
||||
|
||||
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||
continue;
|
||||
}
|
||||
|
||||
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||
ArrayRef<int64_t> shape = weightView->shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
if (globalOp)
|
||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||
}
|
||||
}
|
||||
|
||||
for (pim::PimCoreOp coreOp : scalarCores)
|
||||
if (coreOp.getOperation() != op)
|
||||
coreOp.erase();
|
||||
}
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
|
||||
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
||||
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
size_t maxCoreId,
|
||||
json::Object xbarsPerArrayGroup,
|
||||
StringRef outputDirPath) {
|
||||
json::Object configJson;
|
||||
|
||||
// pimsim-nn indexes cores directly by their numeric core ID, with the host
|
||||
// occupying core 0.
|
||||
configJson["core_cnt"] = maxCoreId + 1;
|
||||
|
||||
// TODO: Should this be based on the floating point type used in the model?
|
||||
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
|
||||
|
||||
// Number of ADC for MVM units
|
||||
configJson["adc_count"] = 16;
|
||||
// The bit precision of each ADC
|
||||
configJson["cell_precision"] = 2;
|
||||
|
||||
// Crossbar configuration
|
||||
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||
|
||||
// Memory layout of inputs and outputs
|
||||
json::Array inputsAddresses;
|
||||
for (BlockArgument input : funcOp.getArguments())
|
||||
inputsAddresses.push_back(memory.getValueAddress(input));
|
||||
configJson["inputs_addresses"] = std::move(inputsAddresses);
|
||||
|
||||
json::Array outputsAddresses;
|
||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||
for (mlir::Value output : returnOp.getOperands())
|
||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||
|
||||
auto configPath = (outputDirPath + "/config.json").str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream jsonOS(configPath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening config file: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
jsonOS << json::Value(std::move(configJson)) << '\n';
|
||||
jsonOS.close();
|
||||
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||
if (!outputDirPath.empty()) {
|
||||
if (auto error = sys::fs::create_directory(outputDirPath)) {
|
||||
@@ -1103,17 +757,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
||||
return err;
|
||||
|
||||
// Write empty host core file
|
||||
std::error_code errorCode;
|
||||
auto outputHostCorePath = outputDirPath + "/core_0.json";
|
||||
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
// The host core json contains 2 random instructions, just to make pimsim-nn happy
|
||||
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
||||
hostFileStream.close();
|
||||
if (auto err = writeHostCoreJson(outputDirPath))
|
||||
return err;
|
||||
|
||||
// For each core, specify the number of crossbar per array group.
|
||||
// This implementation always assigns one crossbar per group.
|
||||
@@ -1145,17 +790,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
}
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
SmallVector<pim::PimCoreOp> scalarCores;
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
scalarCores.push_back(coreOp);
|
||||
}
|
||||
else {
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||
scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
|
||||
}
|
||||
|
||||
for (pim::PimCoreOp coreOp : scalarCores) {
|
||||
auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes {
|
||||
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
maxCoreId = std::max(maxCoreId, coreId);
|
||||
@@ -1210,13 +845,26 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
}
|
||||
|
||||
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
||||
if (temporaryCore)
|
||||
coreOp.walk([&memory](Operation* op) { memory.clean(op); });
|
||||
return CompilerSuccess;
|
||||
};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
if (auto err = emitCore(coreOp, false))
|
||||
return err;
|
||||
continue;
|
||||
}
|
||||
|
||||
for (pim::PimCoreOp coreOp : scalarCores)
|
||||
if (coreOp.getOperation() != op) {
|
||||
coreOp.walk([&memory](Operation* op) { memory.clean(op); });
|
||||
coreOp.erase();
|
||||
}
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
laneResult = emitCore(coreOp, true);
|
||||
return laneResult == CompilerSuccess ? success() : failure();
|
||||
})))
|
||||
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
||||
}
|
||||
}
|
||||
|
||||
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
|
||||
|
||||
@@ -0,0 +1,221 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct DenseWeightView {
|
||||
DenseElementsAttr denseAttr;
|
||||
SmallVector<int64_t> shape;
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
||||
strides[index] = strides[index + 1] * shape[index + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
bool allStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = current.getDefiningOp();
|
||||
if (!defOp)
|
||||
return failure();
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStridesForShape(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
}
|
||||
|
||||
return view;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||
return getUsedWeightIndices(coreOp.getBody().front());
|
||||
}
|
||||
|
||||
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
SmallVector<Operation*> coreLikeOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||
coreLikeOps.push_back(&op);
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||
assert(!error && "Error creating weights directory");
|
||||
size_t indexFileName = 0;
|
||||
|
||||
int64_t xbarSize = crossbarSize.getValue();
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
auto processCore = [&](pim::PimCoreOp coreOp) {
|
||||
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
|
||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||
if (index >= coreOp.getWeights().size()) {
|
||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
mlir::Value weight = coreOp.getWeights()[index];
|
||||
|
||||
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||
if (failed(weightView)) {
|
||||
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||
}
|
||||
|
||||
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||
continue;
|
||||
}
|
||||
|
||||
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||
ArrayRef<int64_t> shape = weightView->shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
if (globalOp)
|
||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
(void) processCore(coreOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
ConversionPatterns.cpp
|
||||
HostFoldability.cpp
|
||||
HostLegality.cpp
|
||||
PrePatterns.cpp
|
||||
PostPatterns.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -30,10 +32,29 @@ SmallVector<Value> sliceTensor(
|
||||
|
||||
for (int64_t i = 0; i < numSlices; i++) {
|
||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||
int64_t currentSliceSize = sliceSize;
|
||||
if (i == numSlices - 1 && lastSliceSize != 0) {
|
||||
currentSliceSize = lastSliceSize;
|
||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||
}
|
||||
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
|
||||
sliceShape[axis] = currentSliceSize;
|
||||
auto sliceType =
|
||||
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||
|
||||
Value slice;
|
||||
if (isHostFoldableValue(tensorToSlice)) {
|
||||
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
}
|
||||
else {
|
||||
auto sliceCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
|
||||
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
|
||||
});
|
||||
slice = sliceCompute.getResult(0);
|
||||
}
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,15 +5,15 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
@@ -105,7 +105,8 @@ inline auto getTensorShape(mlir::Value tensor) {
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -28,7 +28,7 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (hasWeightAlways(definingOp))
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<removeLRN>(ctx);
|
||||
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
populateReduceMeanPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateSigmoidPatterns(patterns, ctx);
|
||||
populateSoftmaxPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateGatherPatterns(patterns, ctx);
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+2
@@ -5,6 +5,8 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
@@ -0,0 +1,75 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
return shapedType && shapedType.hasStaticShape();
|
||||
});
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return isHostFoldableValue(transposeOp.getData());
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isHostFoldableValue(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(definingOp, visited);
|
||||
}
|
||||
|
||||
bool isHostFoldableOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,29 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -8,21 +8,17 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -33,8 +29,6 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||
@@ -44,71 +38,64 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
void populateEmptyFunction(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<spatial::SpatComputeBatch> batchOps;
|
||||
funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
return;
|
||||
|
||||
for (auto batchOp : batchOps) {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
continue;
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLocs;
|
||||
sourceTypes.reserve(funcOp.getNumArguments());
|
||||
sourceLocs.reserve(funcOp.getNumArguments());
|
||||
for (Value source : funcOp.getArguments()) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLocs.push_back(source.getLoc());
|
||||
}
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||
mapper.map(computeArg, blockArg);
|
||||
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : funcOp.getOps())
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp.replaceAllUsesWith(computeOp.getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
}
|
||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
||||
op.dropAllUses();
|
||||
rewriter.eraseOp(&op);
|
||||
}
|
||||
|
||||
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
RewritePatternSet mergeActivationPatterns(ctx);
|
||||
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
||||
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
|
||||
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||
|
||||
IRRewriter rewriter(moduleOp);
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
signalPassFailure();
|
||||
@@ -140,34 +127,23 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXSplitOp>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRN>(ctx);
|
||||
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
populateReduceMeanPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateSigmoidPatterns(patterns, ctx);
|
||||
populateSoftmaxPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateGatherPatterns(patterns, ctx);
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
RewritePatternSet conversionPatterns(ctx);
|
||||
populateConversionPatterns(conversionPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
foldSingleLaneComputeBatches(*entryFunc);
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Count the number of compute ops and check they do not exceed the core count
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
@@ -185,355 +161,23 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||
Value source = funcSource(toRemoveOp);
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
|
||||
auto source = toRemoveOp.getSource();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
||||
auto sources = toRemoveOp.getInputs();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (llvm::any_of(sources,
|
||||
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = spatial::SpatConcatOp::create(rewriter,
|
||||
loc,
|
||||
toRemoveOp.getType(),
|
||||
rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
|
||||
ValueRange(BB->getArguments()));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
|
||||
if (op == nullptr)
|
||||
return false;
|
||||
|
||||
Operation* source = nullptr;
|
||||
do {
|
||||
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
|
||||
return false;
|
||||
}
|
||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
|
||||
auto tmpSource = extractSliceOp.getSource();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
|
||||
auto tmpSource = extractRowsOp.getInput();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
|
||||
auto tmpSource = expandShapeOp.getSrc();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
|
||||
auto tmpSource = transposeOp.getData();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
|
||||
auto tmpSource = collapseShapeOp.getSrc();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
|
||||
source = constantOp;
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
|
||||
bool res = false;
|
||||
for (auto operand : concatOp.getOperands()) {
|
||||
res |= hasWeightAlways(operand.getDefiningOp());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
|
||||
bool res = false;
|
||||
for (auto operand : concatOp.getOperands()) {
|
||||
res |= hasWeightAlways(operand.getDefiningOp());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
else {
|
||||
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
while (source == nullptr);
|
||||
|
||||
return hasWeightAlways(source);
|
||||
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
bool keep = true;
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|
||||
|| isa<func::ReturnOp>(instruction))
|
||||
continue;
|
||||
|
||||
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
|
||||
if (failed(weightBacked))
|
||||
return failure();
|
||||
if (*weightBacked)
|
||||
continue;
|
||||
|
||||
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||
|
||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||
|
||||
keep |= encapsulator<ONNXTransposeOp>(
|
||||
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
||||
|
||||
keep |= encapsulator<tensor::CollapseShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
||||
|
||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
|
||||
for (auto compute : computes) {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto& oldBlock = compute.getBody().front();
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (auto& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
compute.replaceAllUsesWith(newCompute);
|
||||
compute.erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
return;
|
||||
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : funcOp.getArguments()) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(source.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
|
||||
mapper.map(computeArg, bbArg);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
for (Operation& inst : funcOp.getOps())
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
|
||||
rewriter.clone(inst, mapper);
|
||||
|
||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||
|
||||
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
|
||||
inst.dropAllUses();
|
||||
rewriter.eraseOp(&inst);
|
||||
}
|
||||
|
||||
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
static Value transposeForSpatial(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
value,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
input,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -81,6 +121,11 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||
|
||||
if (isHostFoldableValue(matrix)) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
}
|
||||
|
||||
auto buildRowSlices = [&](Value matrixArg) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
@@ -122,7 +167,8 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
rootValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
return buildRowSlices(matrix);
|
||||
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -175,13 +221,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -196,25 +236,18 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
SmallVector<Value> cSlices;
|
||||
if (hasC && cHasNumOutRows)
|
||||
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
gemvOps.reserve(static_cast<size_t>(numOutRows));
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
if (cHasNumOutRows)
|
||||
cSlice = cSlices[static_cast<size_t>(rowIdx)];
|
||||
else if (!isVectorShape(getTensorShape(c))) {
|
||||
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||
return failure();
|
||||
@@ -224,7 +257,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
aSlices[static_cast<size_t>(rowIdx)],
|
||||
b,
|
||||
cSlice,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
@@ -267,13 +300,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
gemmLoc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -305,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
|
||||
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
aType = cast<RankedTensorType>(a.getType());
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
@@ -335,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
@@ -469,12 +496,15 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
@@ -484,13 +514,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto cType = cast<RankedTensorType>(c.getType());
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -36,49 +36,27 @@ static Value extractBatchMatrix(Value value,
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
|
||||
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
auto buildMatrix = [&](Value input) -> Value {
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
};
|
||||
|
||||
static bool isConstantLikeOperand(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
if (isHostFoldableValue(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
auto batchMatrixCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||
});
|
||||
return batchMatrixCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -107,15 +85,31 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -157,7 +151,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
@@ -193,8 +187,14 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
|
||||
if (useTransposedForm) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
gemmResult = transposeCompute.getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
@@ -215,24 +215,30 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
gemmResult,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
gemmResult,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
}));
|
||||
auto batchResultCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
||||
Value resultMatrix = input;
|
||||
if (useTransposedForm) {
|
||||
resultMatrix = ONNXTransposeOp::create(rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
input,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
resultMatrix,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
batchResults.push_back(batchResultCompute.getResult(0));
|
||||
}
|
||||
|
||||
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildReduceMeanKeepdims(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
int64_t axis,
|
||||
@@ -100,7 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
|
||||
for (Value slice : slices)
|
||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, reducedSlices);
|
||||
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
@@ -115,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||
}
|
||||
|
||||
return tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
|
||||
.getResult();
|
||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||
if (isHostFoldableValue(keepdimsValue))
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||
|
||||
auto squeezeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
|
||||
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
|
||||
});
|
||||
return squeezeCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
|
||||
@@ -31,8 +31,8 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
static FailureOr<Value>
|
||||
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
static FailureOr<Value> concatAlongAxis(
|
||||
ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
if (values.empty()) {
|
||||
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
||||
return failure();
|
||||
@@ -68,8 +68,8 @@ reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation*
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
static FailureOr<Value> scaleAverageWindow(
|
||||
ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
if (divisor <= 0) {
|
||||
op->emitOpError("AveragePool divisor must be positive");
|
||||
return failure();
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
@@ -47,7 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
@@ -92,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
result = postTransposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -18,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
auto inputs = adaptor.getInputs();
|
||||
int64_t axis = adaptor.getAxis();
|
||||
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
if (llvm::all_of(inputs, isHostFoldableValue)) {
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(
|
||||
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
|
||||
});
|
||||
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -95,18 +98,33 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||
if (isHostFoldableValue(adaptor.getData())) {
|
||||
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
||||
Value reshaped = buildReshape(data);
|
||||
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
|
||||
});
|
||||
rewriter.replaceOp(reshapeOp, computeOp.getResults());
|
||||
return success();
|
||||
};
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
if (sourceType.getRank() > resultType.getRank()
|
||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||
return replaceWithReshape([&](Value data) {
|
||||
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getRank() < resultType.getRank()
|
||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||
return replaceWithReshape([&](Value data) {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -47,16 +49,40 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
outputs.reserve(splitOp.getNumResults());
|
||||
|
||||
int64_t offset = 0;
|
||||
SmallVector<RankedTensorType> resultTypes;
|
||||
resultTypes.reserve(splitOp.getNumResults());
|
||||
SmallVector<int64_t> sliceSizes;
|
||||
sliceSizes.reserve(splitOp.getNumResults());
|
||||
for (Value result : splitOp.getResults()) {
|
||||
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
int64_t sliceSize = resultType.getShape()[axis];
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
offset += sliceSize;
|
||||
resultTypes.push_back(resultType);
|
||||
sliceSizes.push_back(resultType.getShape()[axis]);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(splitOp, outputs);
|
||||
if (isHostFoldableValue(adaptor.getInput())) {
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
offset += sliceSize;
|
||||
}
|
||||
rewriter.replaceOp(splitOp, outputs);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) {
|
||||
SmallVector<Value> runtimeOutputs;
|
||||
runtimeOutputs.reserve(resultTypes.size());
|
||||
int64_t runtimeOffset = 0;
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
runtimeOffset += sliceSize;
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(splitOp, computeOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<onnxToArithConstant>(ctx);
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,224 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir::pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
|
||||
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
|
||||
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
|
||||
Value input = mapper.lookup(sendTensorBatchOp.getInput());
|
||||
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
|
||||
if (concatOp.getDim() == 0)
|
||||
if (Value packedInput =
|
||||
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
|
||||
input = packedInput;
|
||||
|
||||
pim::PimSendTensorBatchOp::create(
|
||||
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
if (computeBatchOp.getNumResults() != 0)
|
||||
return computeBatchOp.emitOpError(
|
||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
|
||||
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
if (oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
|
||||
|
||||
rewriter.setInsertionPointAfter(computeBatchOp);
|
||||
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
||||
ValueRange(batchWeights),
|
||||
ValueRange(batchInputs));
|
||||
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : oldBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(arg.getLoc());
|
||||
}
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
IRMapping mapper;
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
newArg,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||
.getOutput();
|
||||
mapper.map(oldArg, copied);
|
||||
}
|
||||
|
||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||
if (auto mapped = mapper.lookupOrNull(capturedTensor))
|
||||
return mapped;
|
||||
|
||||
auto capturedType = cast<ShapedType>(capturedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
capturedTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, capturedTensor))
|
||||
.getOutput();
|
||||
mapper.map(capturedTensor, copied);
|
||||
return copied;
|
||||
};
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
loc,
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||
sendBatchOp.getTargetCoreIdsAttr());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||
receiveBatchOp.getSourceCoreIdsAttr())
|
||||
.getOutput();
|
||||
mapper.map(receiveBatchOp.getOutput(), received);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
auto clonedTensor = cloned->getResult(0);
|
||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
clonedTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||
.getOutput();
|
||||
mapper.map(toTensorOp.getResult(), copied);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (Value operand : op.getOperands()) {
|
||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||
continue;
|
||||
|
||||
materializeCapturedTensor(operand);
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
PimHaltOp::create(rewriter, loc);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -4,8 +4,16 @@ add_public_tablegen_target(SpatialToPimIncGen)
|
||||
|
||||
add_pim_library(OMSpatialToPim
|
||||
SpatialToPimPass.cpp
|
||||
BatchCoreLoweringPatterns.cpp
|
||||
ChannelLoweringPatterns.cpp
|
||||
Cleanup.cpp
|
||||
Common.cpp
|
||||
Patterns.cpp
|
||||
ComputeLikeRegionUtils.cpp
|
||||
CoreLoweringPatterns.cpp
|
||||
GlobalTensorMaterialization.cpp
|
||||
PhaseVerification.cpp
|
||||
ReturnPathNormalization.cpp
|
||||
TensorPackingPatterns.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getInput(),
|
||||
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
|
||||
if (op->use_empty()) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
auto outputType = cast<ShapedType>(op.getResult().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received = pim::PimReceiveOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getResult().getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : op.getTargetCoreIds())
|
||||
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
|
||||
auto inputType = cast<RankedTensorType>(op.getInput().getType());
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(op.getNumResults());
|
||||
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
|
||||
auto outputType = cast<RankedTensorType>(output.getType());
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(inputType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
replacements.push_back(
|
||||
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
|
||||
.getResult());
|
||||
}
|
||||
rewriter.replaceOp(op, replacements);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value concatenated =
|
||||
pim::PimConcatOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, concatenated);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<ChannelSendLowering,
|
||||
ChannelReceiveLowering,
|
||||
ChannelSendTensorLowering,
|
||||
ChannelReceiveTensorLowering,
|
||||
ExtractRowsLowering,
|
||||
ConcatLowering>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,42 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
|
||||
while (!pendingOps.empty()) {
|
||||
bool erasedAnyOp = false;
|
||||
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
|
||||
Operation* opToRemove = *it;
|
||||
if (!opToRemove->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(opToRemove);
|
||||
it = pendingOps.erase(it);
|
||||
erasedAnyOp = true;
|
||||
}
|
||||
|
||||
if (erasedAnyOp)
|
||||
continue;
|
||||
|
||||
for (Operation* opToRemove : pendingOps) {
|
||||
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
|
||||
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
|
||||
for (Operation* user : opToRemove->getUsers()) {
|
||||
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
|
||||
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
|
||||
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,44 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||
auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional<unsigned> {
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
unsigned inputBegin = op->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
};
|
||||
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
||||
return getInputIndex(owner, compute.getInputs().size());
|
||||
|
||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
|
||||
return getInputIndex(owner, computeBatch.getInputs().size());
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
Operation* owner,
|
||||
unsigned inputIndex,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = body.getArgument(inputIndex);
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
bodyArgument.replaceAllUsesWith(replacement);
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
||||
compute.getInputsMutable().erase(inputIndex);
|
||||
else
|
||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||
body.eraseArgument(inputIndex);
|
||||
rewriter.finalizeOpModification(owner);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::optional<unsigned> getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber);
|
||||
|
||||
void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter,
|
||||
mlir::Operation* owner,
|
||||
unsigned inputIndex,
|
||||
mlir::Value replacement);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,213 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir::pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool isChannelUseChainOp(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CastOp,
|
||||
tosa::ReshapeOp,
|
||||
ONNXTransposeOp,
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
}
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
return static_cast<int32_t>(fallbackCoreId++);
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||
return failure();
|
||||
if (requireReturnUse
|
||||
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
|
||||
return failure();
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return failure();
|
||||
|
||||
SmallVector<Operation*> reverseChain;
|
||||
Value currentValue = yieldOp.getOperands().front();
|
||||
Value blockArg = block.getArgument(0);
|
||||
|
||||
while (currentValue != blockArg) {
|
||||
Operation* definingOp = currentValue.getDefiningOp();
|
||||
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
|
||||
return failure();
|
||||
reverseChain.push_back(definingOp);
|
||||
currentValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
return failure();
|
||||
|
||||
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
}))
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 0)
|
||||
return false;
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return false;
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
|
||||
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||
computeOp.getResult(0).replaceAllUsesWith(replacement);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||
return success();
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
|
||||
return success();
|
||||
|
||||
auto& block = computeOp.getRegion().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
||||
if (!receiveOp || blockArg.use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
||||
Value received = PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveOp);
|
||||
}
|
||||
|
||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
||||
|
||||
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
|
||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
||||
ReturnPathLoweringResult returnPathResult =
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||
return failure();
|
||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||
continue;
|
||||
|
||||
auto resultUses = result.getUses();
|
||||
if (rangeLength(resultUses) == 1) {
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
if (isa<spatial::SpatChannelSendOp>(resultUser))
|
||||
continue;
|
||||
}
|
||||
|
||||
return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
||||
|
||||
SmallVector<Value> computeWeights;
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(rewriter,
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||
if (!blockArg.use_empty())
|
||||
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
||||
block.eraseArguments(0, block.getNumArguments());
|
||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||
Block* tempComputeBlock = new Block();
|
||||
computeOp.getBody().push_back(tempComputeBlock);
|
||||
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
||||
PimHaltOp::create(rewriter, computeOp.getLoc());
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CoreLoweringState {
|
||||
size_t& nextCoreId;
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+70
-104
@@ -6,16 +6,17 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -23,33 +24,33 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
|
||||
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
|
||||
unsigned inputBegin = compute->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
}
|
||||
|
||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
|
||||
unsigned inputCount = computeBatch.getInputs().size();
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
|
||||
unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
|
||||
name = (baseName + "_" + Twine(suffix++)).str();
|
||||
return name;
|
||||
}
|
||||
|
||||
static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ModuleOp moduleOp,
|
||||
StringRef baseName,
|
||||
MemRefType type,
|
||||
Attribute initialValue = {},
|
||||
UnitAttr constant = {}) {
|
||||
std::string symbolName = makeUniqueSymbolName(moduleOp, baseName);
|
||||
return memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(symbolName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(type),
|
||||
initialValue,
|
||||
constant,
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work.
|
||||
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -59,7 +60,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
|
||||
for (auto& uses : extractSliceOp->getUses()) {
|
||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||
if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||
return failure();
|
||||
}
|
||||
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||
@@ -72,7 +73,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
@@ -87,14 +88,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
|
||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
@@ -109,11 +107,8 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
@@ -148,11 +143,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
};
|
||||
|
||||
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||
static int i = 0;
|
||||
Location loc = constantOp.getLoc();
|
||||
|
||||
if (hasWeightAlways(constantOp))
|
||||
@@ -177,15 +172,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
if (constRankedTensorType) {
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||
std::string argName = "const_" + std::to_string(i++);
|
||||
memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(argName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
constantOp.getValueAttr(),
|
||||
rewriter.getUnitAttr(),
|
||||
{});
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
||||
loc,
|
||||
constantOp->getParentOfType<ModuleOp>(),
|
||||
"const",
|
||||
memRefType,
|
||||
constantOp.getValueAttr(),
|
||||
rewriter.getUnitAttr());
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
@@ -193,11 +187,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
@@ -206,18 +199,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
@@ -226,11 +215,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter,
|
||||
spatComputeBatch.getOperation(),
|
||||
BBArgIndex,
|
||||
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
@@ -272,34 +260,26 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
}
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
@@ -321,11 +301,13 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriter.eraseOp(constantOp);
|
||||
if (constantOp->use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -352,52 +334,36 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
|
||||
|
||||
std::string argName = "arg_" + std::to_string(index);
|
||||
|
||||
memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(argName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
std::string baseName = ("arg_" + Twine(index)).str();
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(
|
||||
rewriter, loc, funcOp->getParentOfType<ModuleOp>(), baseName, memRefType);
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||
auto argUser = argUses.getOwner();
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(toTensor);
|
||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
BBArgValue.replaceAllUsesWith(toTensor);
|
||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor);
|
||||
}
|
||||
else {
|
||||
rewriter.setInsertionPoint(argUser);
|
||||
@@ -416,7 +382,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
};
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
|
||||
bool hasFailure = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
|
||||
hasFailure = true;
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,587 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir::pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct ReturnUseInfo {
|
||||
size_t returnIndex;
|
||||
SmallVector<Operation*> helperChain;
|
||||
};
|
||||
|
||||
struct ConcatReturnUseInfo {
|
||||
size_t returnIndex;
|
||||
SmallVector<int64_t> sliceOffsets;
|
||||
SmallVector<int64_t> concatShape;
|
||||
SmallVector<Operation*> concatChain;
|
||||
SmallVector<Operation*> helperChain;
|
||||
};
|
||||
|
||||
static bool isReturnHelperChainOp(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CastOp,
|
||||
tosa::ReshapeOp,
|
||||
ONNXTransposeOp,
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
|
||||
name = (baseName + "_" + Twine(suffix++)).str();
|
||||
return name;
|
||||
}
|
||||
|
||||
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
|
||||
int64_t flatIndex = 0;
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
flatIndex *= shape[i];
|
||||
flatIndex += indices[i];
|
||||
}
|
||||
return flatIndex;
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> indices(shape.size(), 0);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||
indices[dim] = flatIndex % shape[dim];
|
||||
flatIndex /= shape[dim];
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain) {
|
||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||
return failure();
|
||||
if (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin()))
|
||||
return failure();
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return failure();
|
||||
|
||||
SmallVector<Operation*> reverseChain;
|
||||
Value currentValue = yieldOp.getOperands().front();
|
||||
Value blockArg = block.getArgument(0);
|
||||
|
||||
while (currentValue != blockArg) {
|
||||
Operation* definingOp = currentValue.getDefiningOp();
|
||||
if (!definingOp || definingOp->getBlock() != &block || !isReturnHelperChainOp(definingOp))
|
||||
return failure();
|
||||
reverseChain.push_back(definingOp);
|
||||
currentValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
return failure();
|
||||
|
||||
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||
return success();
|
||||
}
|
||||
|
||||
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
||||
auto uses = value.getUses();
|
||||
if (rangeLength(uses) != 1)
|
||||
return std::nullopt;
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
Value currentValue = value;
|
||||
Operation* currentUser = uses.begin()->getOwner();
|
||||
|
||||
while (isReturnHelperChainOp(currentUser)) {
|
||||
helperChain.push_back(currentUser);
|
||||
auto currentUses = currentUser->getResult(0).getUses();
|
||||
if (rangeLength(currentUses) != 1)
|
||||
return std::nullopt;
|
||||
currentValue = currentUser->getResult(0);
|
||||
currentUser = currentUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
if (!isa<func::ReturnOp>(currentUser))
|
||||
return std::nullopt;
|
||||
|
||||
return ReturnUseInfo {
|
||||
currentValue.getUses().begin()->getOperandNumber(),
|
||||
std::move(helperChain),
|
||||
};
|
||||
}
|
||||
|
||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatResult = [](Operation* op) -> Value {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getResult();
|
||||
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return spatialConcat.getOutput();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getOutput();
|
||||
return {};
|
||||
};
|
||||
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getDim();
|
||||
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return spatialConcat.getAxis();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getAxis();
|
||||
return std::nullopt;
|
||||
};
|
||||
auto getConcatOperands = [](Operation* op) -> OperandRange {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getOperands();
|
||||
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return spatialConcat.getInputs();
|
||||
return cast<pim::PimConcatOp>(op).getInputs();
|
||||
};
|
||||
|
||||
auto uses = value.getUses();
|
||||
if (rangeLength(uses) != 1
|
||||
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
|
||||
return std::nullopt;
|
||||
|
||||
auto valueType = dyn_cast<ShapedType>(value.getType());
|
||||
if (!valueType || !valueType.hasStaticShape())
|
||||
return std::nullopt;
|
||||
|
||||
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
|
||||
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
|
||||
SmallVector<Operation*> concatChain;
|
||||
Value currentValue = value;
|
||||
Operation* currentUser = uses.begin()->getOwner();
|
||||
|
||||
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
|
||||
concatChain.push_back(currentUser);
|
||||
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
|
||||
int64_t axis = *getConcatAxis(currentUser);
|
||||
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
|
||||
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
|
||||
|
||||
Value concatResult = getConcatResult(currentUser);
|
||||
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
|
||||
if (!concatType || !concatType.hasStaticShape())
|
||||
return std::nullopt;
|
||||
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
|
||||
|
||||
currentValue = concatResult;
|
||||
auto currentUses = currentValue.getUses();
|
||||
if (rangeLength(currentUses) != 1)
|
||||
return std::nullopt;
|
||||
currentUser = currentUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
||||
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||
return std::nullopt;
|
||||
|
||||
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
|
||||
return std::nullopt;
|
||||
|
||||
currentValue = helperCompute.getResult(0);
|
||||
auto currentUses = currentValue.getUses();
|
||||
if (rangeLength(currentUses) != 1)
|
||||
return std::nullopt;
|
||||
currentUser = currentUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
while (isReturnHelperChainOp(currentUser)) {
|
||||
helperChain.push_back(currentUser);
|
||||
auto currentUses = currentUser->getResult(0).getUses();
|
||||
if (rangeLength(currentUses) != 1)
|
||||
return std::nullopt;
|
||||
currentValue = currentUser->getResult(0);
|
||||
currentUser = currentUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
if (!isa<func::ReturnOp>(currentUser))
|
||||
return std::nullopt;
|
||||
|
||||
return ConcatReturnUseInfo {
|
||||
currentValue.getUses().begin()->getOperandNumber(),
|
||||
std::move(sliceOffsets),
|
||||
std::move(concatShape),
|
||||
std::move(concatChain),
|
||||
std::move(helperChain),
|
||||
};
|
||||
}
|
||||
|
||||
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
|
||||
ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<Operation*> helperChain,
|
||||
SmallVectorImpl<int64_t>& mappedIndices) {
|
||||
SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
|
||||
SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
|
||||
|
||||
auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
|
||||
auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
|
||||
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
currentIndices = expandFlatElementIndex(flatIndex, currentShape);
|
||||
return success();
|
||||
};
|
||||
|
||||
for (Operation* op : helperChain) {
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> nextIndices;
|
||||
nextIndices.reserve(currentIndices.size());
|
||||
for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
|
||||
extractSliceOp.getStaticOffsets(),
|
||||
extractSliceOp.getStaticSizes(),
|
||||
extractSliceOp.getStaticStrides())) {
|
||||
if (stride != 1 || index < offset || index >= offset + size)
|
||||
return failure();
|
||||
nextIndices.push_back(index - offset);
|
||||
}
|
||||
|
||||
auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
currentIndices = std::move(nextIndices);
|
||||
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
|
||||
SmallVector<int64_t> nextIndices(currentIndices.size());
|
||||
SmallVector<int64_t> nextShape(currentShape.size());
|
||||
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
|
||||
int64_t sourceIndex = attr.getInt();
|
||||
nextIndices[destIndex] = currentIndices[sourceIndex];
|
||||
nextShape[destIndex] = currentShape[sourceIndex];
|
||||
}
|
||||
currentIndices = std::move(nextIndices);
|
||||
currentShape = std::move(nextShape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
|
||||
SmallVector<int64_t> nextIndices(currentIndices.size());
|
||||
SmallVector<int64_t> nextShape(currentShape.size());
|
||||
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
|
||||
int64_t sourceIndex = attr.getInt();
|
||||
nextIndices[destIndex] = currentIndices[sourceIndex];
|
||||
nextShape[destIndex] = currentShape[sourceIndex];
|
||||
}
|
||||
currentIndices = std::move(nextIndices);
|
||||
currentShape = std::move(nextShape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
|
||||
if (failed(reshapeToResultShape(op)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
mappedIndices.assign(currentIndices.begin(), currentIndices.end());
|
||||
return success();
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||
IRMapping mapping;
|
||||
mapping.map(sourceValue, sourceValue);
|
||||
clonedValue = sourceValue;
|
||||
|
||||
rewriter.setInsertionPointAfterValue(sourceValue);
|
||||
for (Operation* op : helperChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
clonedValue = clonedOp->getResult(0);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
}
|
||||
|
||||
static Value emitHostCopy(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Value outputTensor,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes) {
|
||||
return PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Value currentReturnValue = returnValue;
|
||||
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
|
||||
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||
outputTensors.push_back(
|
||||
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
|
||||
}
|
||||
else {
|
||||
auto outRankedTensorType = llvm::dyn_cast<RankedTensorType>(currentReturnValue.getType());
|
||||
auto memRefType = MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
|
||||
|
||||
std::string outputBaseName = ("output_" + Twine(index)).str();
|
||||
std::string outputName = makeUniqueSymbolName(returnOp->getParentOfType<ModuleOp>(), outputBaseName);
|
||||
rewriter.setInsertionPoint(returnOp.getParentOp());
|
||||
memref::GlobalOp::create(rewriter,
|
||||
returnOp.getLoc(),
|
||||
rewriter.getStringAttr(outputName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
outputTensors.push_back([memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
return toTensor.getResult();
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||
|
||||
if (auto returnUse = analyzeReturnUse(result)) {
|
||||
Value storedValue = yieldValue;
|
||||
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(state, op);
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto resultUses = result.getUses();
|
||||
if (rangeLength(resultUses) == 1) {
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(state, concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
static_cast<int32_t>(flatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
|
||||
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
||||
|
||||
SmallVector<int64_t> destinationIndices;
|
||||
if (failed(mapIndicesThroughHelperChain(
|
||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets;
|
||||
SmallVector<OpFoldResult> extractSizes;
|
||||
SmallVector<OpFoldResult> extractStrides;
|
||||
extractOffsets.reserve(storedType.getRank());
|
||||
extractSizes.reserve(storedType.getRank());
|
||||
extractStrides.reserve(storedType.getRank());
|
||||
for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
|
||||
extractOffsets.push_back(rewriter.getIndexAttr(idx));
|
||||
extractSizes.push_back(rewriter.getIndexAttr(1));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
auto scalarTensorType =
|
||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter.setInsertionPointAfter(elementSlice);
|
||||
|
||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||
outputTensor = emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize));
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||
if (!op)
|
||||
return;
|
||||
|
||||
bool isExclusivelyOwnedByReturnChain = op->use_empty();
|
||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||
Operation* onlyUser = *op->getUsers().begin();
|
||||
isExclusivelyOwnedByReturnChain =
|
||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
||||
|| isReturnHelperChainOp(onlyUser);
|
||||
}
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
return;
|
||||
|
||||
if (isReturnHelperChainOp(op)) {
|
||||
Value source = op->getOperand(0);
|
||||
markOpToRemove(state, op);
|
||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(state, computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
}
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
auto loc = returnOp.getLoc();
|
||||
for (auto it : llvm::enumerate(originalOperands)) {
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
struct ReturnPathState {
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,113 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
|
||||
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
|
||||
if (op.getDim() != 0)
|
||||
return failure();
|
||||
|
||||
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
|
||||
if (!packed)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, packed);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
packedShape[0] *= count;
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
|
||||
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
|
||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!firstSliceOp)
|
||||
return {};
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
||||
|| firstType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
||||
int64_t rowsPerValue = firstSizes[0];
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
for (size_t index = 1; index < values.size(); ++index) {
|
||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
||||
return {};
|
||||
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
||||
return {};
|
||||
}
|
||||
|
||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(firstType.getRank());
|
||||
sizes.reserve(firstType.getRank());
|
||||
strides.reserve(firstType.getRank());
|
||||
|
||||
offsets.push_back(builder.getIndexAttr(firstOffsets[0]));
|
||||
sizes.push_back(builder.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
||||
strides.push_back(builder.getIndexAttr(firstStrides[0]));
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
||||
offsets.push_back(builder.getIndexAttr(firstOffsets[dim]));
|
||||
sizes.push_back(builder.getIndexAttr(firstSizes[dim]));
|
||||
strides.push_back(builder.getIndexAttr(firstStrides[dim]));
|
||||
}
|
||||
|
||||
bool coversWholeSource = packedType == sourceType;
|
||||
for (int64_t dim = 0; coversWholeSource && dim < sourceType.getRank(); ++dim)
|
||||
coversWholeSource = firstOffsets[dim] == 0 && firstStrides[dim] == 1;
|
||||
if (coversWholeSource)
|
||||
return firstSliceOp.getSource();
|
||||
|
||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||
|
||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -113,6 +113,18 @@ def PimSendBatchOp : PimOp<"send_batch", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive a tensor from another core";
|
||||
|
||||
@@ -181,6 +193,28 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
|
||||
@@ -174,6 +174,33 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendTensorOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -275,6 +302,43 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutputBuffer().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
Type outputBufferType;
|
||||
Type outputType;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
||||
|| parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis() << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
|
||||
@@ -46,12 +46,47 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
return op->emitError() << kind << " must be nested inside pim.core_batch";
|
||||
|
||||
int32_t laneCount = coreBatchOp.getLaneCount();
|
||||
if (laneCount <= 0)
|
||||
return op->emitError() << kind << " requires a positive parent laneCount";
|
||||
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
|
||||
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
@@ -60,6 +95,15 @@ LogicalResult PimReceiveTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
return verifyTensorBatchCommunication(
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousType,
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
FailureOr<Value>
|
||||
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
|
||||
if (isa<BufferLikeType>(value.getType()))
|
||||
return value;
|
||||
return getBuffer(rewriter, value, options, state);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
|
||||
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
|
||||
mlir::Value value,
|
||||
const mlir::bufferization::BufferizationOptions& options,
|
||||
mlir::bufferization::BufferizationState& state);
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -4,6 +4,8 @@ add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_pim_library(OMPimBufferization
|
||||
PimBufferizationPass.cpp
|
||||
BufferizationUtils.hpp
|
||||
BufferizationUtils.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
Common.hpp
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
@@ -14,33 +15,6 @@ using namespace bufferization;
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousType,
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
|
||||
if (isa<BufferLikeType>(value.getType()))
|
||||
return value;
|
||||
return getBuffer(rewriter, value, options, state);
|
||||
}
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -201,6 +175,27 @@ struct ReceiveTensorOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveTensorBatchOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -308,6 +303,31 @@ struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOp
|
||||
}
|
||||
};
|
||||
|
||||
struct SendTensorBatchOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto sendOp = cast<PimSendTensorBatchOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
|
||||
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
@@ -623,9 +643,11 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
|
||||
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
|
||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
||||
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
||||
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
|
||||
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
|
||||
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
|
||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
@@ -102,23 +102,6 @@ def SpatConcatOp : SpatOp<"concat", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatMapOp : SpatOp<"map", [SingleBlock]> {
|
||||
let summary = "Apply the same lane-local region to many independent tensors";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$inputs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<SpatTensor>:$outputs
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -156,22 +139,25 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
|
||||
let summary = "Send multiple tensors through logical channels";
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<SpatTensor>:$inputs
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
|
||||
let summary = "Receive multiple tensors from logical channels";
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
@@ -180,11 +166,14 @@ def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<SpatTensor>:$outputs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
@@ -201,18 +190,21 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> {
|
||||
let summary = "Send multiple per-lane tensors through logical channels in a batch body";
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<SpatTensor>:$inputs
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
@@ -232,8 +224,8 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
||||
let summary = "Receive multiple per-lane tensors through logical channels in a batch body";
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
|
||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
@@ -242,11 +234,14 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<SpatTensor>:$outputs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -129,9 +129,8 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs)) {
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
@@ -151,46 +150,6 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatMapOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInputs().front().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
Type inputType;
|
||||
Type outputType;
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (inputs.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
Region* body = result.addRegion();
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
@@ -357,97 +316,6 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -494,55 +362,6 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -584,47 +403,5 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -105,26 +105,28 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.size() != valueCount)
|
||||
return op->emitError("channel metadata length must match the number of values");
|
||||
return success();
|
||||
}
|
||||
if (channelIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
|
||||
if (types.empty())
|
||||
return op->emitError() << kind << " must carry at least one value";
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
Type firstType = types.front();
|
||||
for (Type type : types.drop_front())
|
||||
if (type != firstType)
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -144,19 +146,33 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match the number of values times parent laneCount");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -323,39 +339,6 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatMapOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
if (getOutputs().size() != getInputs().size())
|
||||
return emitError("number of outputs must match number of inputs");
|
||||
|
||||
Type inputType = getInputs().front().getType();
|
||||
for (Value input : getInputs().drop_front())
|
||||
if (input.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
|
||||
Type outputType = getOutputs().front().getType();
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != inputType)
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("body must terminate with spat.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputType)
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
@@ -397,40 +380,48 @@ LogicalResult SpatCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatComputeBatch::verify() {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
@@ -31,6 +32,8 @@ namespace {
|
||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
||||
|
||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
SmallVector<size_t, 4> originalComputeIndices;
|
||||
Weight weight = 0;
|
||||
@@ -719,11 +722,12 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (oldNodeCount >= 200)
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
@@ -737,7 +741,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
@@ -755,7 +759,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
@@ -764,7 +768,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
@@ -776,7 +780,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
@@ -786,7 +790,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
|
||||
@@ -59,11 +59,7 @@ struct DenseMapInfo<ComputeInstance> {
|
||||
static ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) {
|
||||
return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
|
||||
}
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
|
||||
return a == b;
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
@@ -38,9 +38,11 @@ void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += t
|
||||
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
|
||||
if (!logProgress)
|
||||
return;
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
|
||||
totalTasks, readyCount, maxCpuCount, xbarsCapacity);
|
||||
llvm::errs() << llvm::formatv("[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
|
||||
totalTasks,
|
||||
readyCount,
|
||||
maxCpuCount,
|
||||
xbarsCapacity);
|
||||
}
|
||||
|
||||
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
|
||||
@@ -72,18 +74,17 @@ void DcpProgressLogger::printProgress(
|
||||
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
|
||||
|
||||
bool done = completedTasks == totalTasks;
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
|
||||
completedTasks,
|
||||
totalTasks,
|
||||
percent,
|
||||
readyCount,
|
||||
cpuCount,
|
||||
maxCpuCount,
|
||||
xbarsUsed,
|
||||
xbarsAvailable,
|
||||
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
|
||||
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
|
||||
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
|
||||
completedTasks,
|
||||
totalTasks,
|
||||
percent,
|
||||
readyCount,
|
||||
cpuCount,
|
||||
maxCpuCount,
|
||||
xbarsUsed,
|
||||
xbarsAvailable,
|
||||
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
|
||||
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
|
||||
lastProgressPrint = now;
|
||||
}
|
||||
|
||||
@@ -100,9 +101,7 @@ void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
|
||||
|
||||
#endif
|
||||
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
|
||||
const std::vector<std::list<TaskDCP*>>& cpuTasks,
|
||||
CPU lastCpu) {
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu) {
|
||||
static int dumpIndex = 0;
|
||||
std::string outputDir = onnx_mlir::getOutputDir();
|
||||
if (outputDir.empty())
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
#include "Task.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
// Uncomment to enable DCP progress logging and per-phase profiling during
|
||||
// development. When disabled the logger methods are no-ops and the helpers
|
||||
// compile away.
|
||||
// Define DCP_DEBUG_ENABLED locally when debugging DCP progress and per-phase
|
||||
// profiling. In normal builds the logger methods are no-ops and helpers compile
|
||||
// away.
|
||||
#define DCP_DEBUG_ENABLED
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
@@ -33,10 +33,11 @@ public:
|
||||
|
||||
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
|
||||
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
|
||||
void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
|
||||
size_t xbarsUsed, size_t xbarsAvailable, bool force);
|
||||
void
|
||||
printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force);
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
|
||||
private:
|
||||
static std::string formatDuration(double seconds);
|
||||
|
||||
@@ -51,8 +52,6 @@ private:
|
||||
#endif
|
||||
};
|
||||
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
|
||||
const std::vector<std::list<TaskDCP*>>& cpuTasks,
|
||||
CPU lastCpu);
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu);
|
||||
|
||||
} // namespace dcp_graph
|
||||
|
||||
@@ -149,14 +149,6 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
|
||||
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||
@@ -245,312 +237,6 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||
}
|
||||
|
||||
static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<Operation*>& opsToErase,
|
||||
int64_t& nextChannelId) {
|
||||
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
|
||||
|
||||
for (auto batch : batches) {
|
||||
if (batch.getInputs().empty() && batch.getResults().empty())
|
||||
continue;
|
||||
|
||||
if (batch.getInputs().size() != static_cast<size_t>(batch.getLaneCount()))
|
||||
continue;
|
||||
if (batch.getResults().size() != static_cast<size_t>(batch.getLaneCount()))
|
||||
continue;
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> inputReceives;
|
||||
inputReceives.reserve(batch.getInputs().size());
|
||||
bool allInputsAreReceives = true;
|
||||
for (Value input : batch.getInputs()) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (!receiveOp) {
|
||||
allInputsAreReceives = false;
|
||||
break;
|
||||
}
|
||||
inputReceives.push_back(receiveOp);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelSendOp> resultSends;
|
||||
resultSends.reserve(batch.getResults().size());
|
||||
bool allResultsAreSingleSends = true;
|
||||
for (Value result : batch.getResults()) {
|
||||
if (!result.hasOneUse()) {
|
||||
allResultsAreSingleSends = false;
|
||||
break;
|
||||
}
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(*result.getUsers().begin());
|
||||
if (!sendOp) {
|
||||
allResultsAreSingleSends = false;
|
||||
break;
|
||||
}
|
||||
resultSends.push_back(sendOp);
|
||||
}
|
||||
|
||||
if (!allInputsAreReceives || !allResultsAreSingleSends)
|
||||
continue;
|
||||
|
||||
Block& oldBlock = batch.getBody().front();
|
||||
if (oldBlock.getNumArguments() != 1)
|
||||
continue;
|
||||
|
||||
SmallVector<Value> newWeights(batch.getWeights().begin(), batch.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(batch);
|
||||
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||
batch.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(batch.getLaneCount()),
|
||||
ValueRange(newWeights),
|
||||
ValueRange {});
|
||||
newBatch.getProperties().setOperandSegmentSizes({static_cast<int>(newWeights.size()), 0});
|
||||
|
||||
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
|
||||
if (!coreIds.empty())
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
|
||||
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
struct BatchReceiveEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchReceiveEntry> receiveEntries;
|
||||
receiveEntries.reserve(inputReceives.size());
|
||||
for (auto receiveOp : inputReceives)
|
||||
receiveEntries.push_back({receiveOp.getChannelId(), receiveOp.getSourceCoreId(), receiveOp.getTargetCoreId()});
|
||||
llvm::stable_sort(receiveEntries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> receiveChannelIds;
|
||||
SmallVector<int32_t> receiveSourceCoreIds;
|
||||
SmallVector<int32_t> receiveTargetCoreIds;
|
||||
receiveChannelIds.reserve(receiveEntries.size());
|
||||
receiveSourceCoreIds.reserve(receiveEntries.size());
|
||||
receiveTargetCoreIds.reserve(receiveEntries.size());
|
||||
for (const BatchReceiveEntry& entry : receiveEntries) {
|
||||
(void) entry;
|
||||
receiveChannelIds.push_back(nextChannelId++);
|
||||
receiveSourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
receiveTargetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||
batch.getLoc(),
|
||||
oldBlock.getArgument(0).getType(),
|
||||
rewriter.getDenseI64ArrayAttr(receiveChannelIds),
|
||||
rewriter.getDenseI32ArrayAttr(receiveSourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(receiveTargetCoreIds));
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(oldBlock.getArgument(0), batchReceive.getOutput());
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : oldBlock) {
|
||||
if (&op == oldYield)
|
||||
continue;
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
Value sendInput = mapper.lookup(oldYield.getOperand(0));
|
||||
struct BatchSendEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchSendEntry> sendEntries;
|
||||
sendEntries.reserve(resultSends.size());
|
||||
for (auto sendOp : resultSends)
|
||||
sendEntries.push_back({sendOp.getChannelId(), sendOp.getSourceCoreId(), sendOp.getTargetCoreId()});
|
||||
llvm::stable_sort(sendEntries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> sendChannelIds;
|
||||
SmallVector<int32_t> sendSourceCoreIds;
|
||||
SmallVector<int32_t> sendTargetCoreIds;
|
||||
sendChannelIds.reserve(sendEntries.size());
|
||||
sendSourceCoreIds.reserve(sendEntries.size());
|
||||
sendTargetCoreIds.reserve(sendEntries.size());
|
||||
for (const BatchSendEntry& entry : sendEntries) {
|
||||
(void) entry;
|
||||
sendChannelIds.push_back(nextChannelId++);
|
||||
sendSourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
sendTargetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
batch.getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(sendChannelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sendSourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(sendTargetCoreIds),
|
||||
sendInput);
|
||||
spatial::SpatYieldOp::create(rewriter, batch.getLoc(), ValueRange {});
|
||||
|
||||
for (auto receiveOp : inputReceives)
|
||||
opsToErase.push_back(receiveOp);
|
||||
for (auto sendOp : resultSends)
|
||||
opsToErase.push_back(sendOp);
|
||||
opsToErase.push_back(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
SmallVector<Operation*> opsToErase;
|
||||
|
||||
for (auto compute : computes) {
|
||||
SmallVector<unsigned> keptInputIndices;
|
||||
SmallVector<unsigned> keptResultIndices;
|
||||
SmallVector<spatial::SpatChannelReceiveOp> internalizedReceives(compute.getInputs().size());
|
||||
SmallVector<SmallVector<spatial::SpatChannelSendOp>> resultSendOps(compute.getNumResults());
|
||||
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (!receiveOp) {
|
||||
keptInputIndices.push_back(inputIndex);
|
||||
continue;
|
||||
}
|
||||
|
||||
internalizedReceives[inputIndex] = receiveOp;
|
||||
opsToErase.push_back(receiveOp);
|
||||
needsRewrite = true;
|
||||
}
|
||||
|
||||
for (auto [resultIndex, result] : llvm::enumerate(compute.getResults())) {
|
||||
bool hasNonSendUser = false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(user)) {
|
||||
resultSendOps[resultIndex].push_back(sendOp);
|
||||
opsToErase.push_back(sendOp);
|
||||
needsRewrite = true;
|
||||
continue;
|
||||
}
|
||||
hasNonSendUser = true;
|
||||
}
|
||||
|
||||
if (hasNonSendUser || resultSendOps[resultIndex].empty())
|
||||
keptResultIndices.push_back(resultIndex);
|
||||
}
|
||||
|
||||
if (!needsRewrite)
|
||||
continue;
|
||||
|
||||
SmallVector<Value> newOperands;
|
||||
SmallVector<Type> newResultTypes;
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newOperands.reserve(compute.getNumOperands());
|
||||
newResultTypes.reserve(keptResultIndices.size());
|
||||
newBlockArgTypes.reserve(keptInputIndices.size());
|
||||
newBlockArgLocs.reserve(keptInputIndices.size());
|
||||
|
||||
for (Value weight : compute.getWeights())
|
||||
newOperands.push_back(weight);
|
||||
for (unsigned inputIndex : keptInputIndices) {
|
||||
Value input = compute.getInputs()[inputIndex];
|
||||
newOperands.push_back(input);
|
||||
newBlockArgTypes.push_back(input.getType());
|
||||
newBlockArgLocs.push_back(compute.getLoc());
|
||||
}
|
||||
for (unsigned resultIndex : keptResultIndices)
|
||||
newResultTypes.push_back(compute.getResult(resultIndex).getType());
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
auto newCompute =
|
||||
SpatCompute::create(rewriter, compute.getLoc(), TypeRange(newResultTypes), ValueRange(newOperands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(keptInputIndices.size())});
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, coreIdAttr);
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [mappedIndex, inputIndex] : llvm::enumerate(keptInputIndices))
|
||||
mapper.map(compute.getBody().front().getArgument(inputIndex), newBlock->getArgument(mappedIndex));
|
||||
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
for (auto [inputIndex, receiveOp] : llvm::enumerate(internalizedReceives)) {
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
auto internalReceive = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
receiveOp.getResult().getType(),
|
||||
receiveOp.getChannelIdAttr(),
|
||||
receiveOp.getSourceCoreIdAttr(),
|
||||
receiveOp.getTargetCoreIdAttr());
|
||||
mapper.map(compute.getBody().front().getArgument(inputIndex), internalReceive.getResult());
|
||||
}
|
||||
|
||||
auto oldYieldOp = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : compute.getBody().front()) {
|
||||
if (&op == oldYieldOp)
|
||||
continue;
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
for (auto [resultIndex, sendOps] : llvm::enumerate(resultSendOps)) {
|
||||
if (sendOps.empty())
|
||||
continue;
|
||||
|
||||
Value yieldedValue = mapper.lookup(oldYieldOp.getOperand(resultIndex));
|
||||
for (auto sendOp : sendOps)
|
||||
spatial::SpatChannelSendOp::create(rewriter,
|
||||
sendOp.getLoc(),
|
||||
sendOp.getChannelIdAttr(),
|
||||
sendOp.getSourceCoreIdAttr(),
|
||||
sendOp.getTargetCoreIdAttr(),
|
||||
yieldedValue);
|
||||
}
|
||||
|
||||
SmallVector<Value> keptYieldOperands;
|
||||
keptYieldOperands.reserve(keptResultIndices.size());
|
||||
for (unsigned resultIndex : keptResultIndices)
|
||||
keptYieldOperands.push_back(mapper.lookup(oldYieldOp.getOperand(resultIndex)));
|
||||
spatial::SpatYieldOp::create(rewriter, oldYieldOp.getLoc(), ValueRange(keptYieldOperands));
|
||||
|
||||
for (auto [newResultIndex, oldResultIndex] : llvm::enumerate(keptResultIndices))
|
||||
compute.getResult(oldResultIndex).replaceAllUsesWith(newCompute.getResult(newResultIndex));
|
||||
|
||||
opsToErase.push_back(compute);
|
||||
}
|
||||
|
||||
sinkChannelsIntoBatchComputes(funcOp, rewriter, opsToErase, nextChannelId);
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(opsToErase.begin(), opsToErase.end());
|
||||
while (!pendingRemovals.empty()) {
|
||||
bool erasedAny = false;
|
||||
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
|
||||
if (!(*it)->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(*it);
|
||||
it = pendingRemovals.erase(it);
|
||||
erasedAny = true;
|
||||
}
|
||||
|
||||
if (erasedAny)
|
||||
continue;
|
||||
|
||||
for (Operation* op : pendingRemovals)
|
||||
op->emitError("failed to sink channel op into compute");
|
||||
llvm_unreachable("channel sinking left cyclic top-level dependencies");
|
||||
}
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
@@ -1280,7 +966,8 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
mergeTriviallyConnectedComputes(getOperation());
|
||||
emitMotifProfile(getOperation());
|
||||
if (std::getenv("DCP_MOTIF_PROFILE"))
|
||||
emitMotifProfile(getOperation());
|
||||
|
||||
func::FuncOp func = getOperation();
|
||||
Location loc = func.getLoc();
|
||||
@@ -1718,17 +1405,12 @@ public:
|
||||
for (Operation* user : result.getUsers())
|
||||
remainingUsers.push_back(user);
|
||||
if (!remainingUsers.empty()) {
|
||||
llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n";
|
||||
llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n";
|
||||
op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||
llvm::errs() << "\n";
|
||||
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
||||
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
|
||||
for (Operation* user : remainingUsers) {
|
||||
llvm::errs() << " user: " << user->getName()
|
||||
<< " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n";
|
||||
user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||
llvm::errs() << "\n";
|
||||
diagnostic.attachNote(user->getLoc())
|
||||
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
|
||||
}
|
||||
op->emitOpError("still has uses during per-cpu merge cleanup");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
@@ -40,6 +41,176 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
packedShape[0] *= count;
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
|
||||
static Value
|
||||
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(chunkType.getRank());
|
||||
sizes.reserve(chunkType.getRank());
|
||||
strides.reserve(chunkType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
|
||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!firstSliceOp)
|
||||
return {};
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
||||
|| firstType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
||||
int64_t rowsPerValue = firstSizes[0];
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
for (size_t index = 1; index < values.size(); ++index) {
|
||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
||||
return {};
|
||||
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
||||
return {};
|
||||
}
|
||||
|
||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(firstType.getRank());
|
||||
sizes.reserve(firstType.getRank());
|
||||
strides.reserve(firstType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
|
||||
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
|
||||
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
||||
if (values.empty())
|
||||
return false;
|
||||
|
||||
auto firstResult = dyn_cast<OpResult>(values.front());
|
||||
if (!firstResult)
|
||||
return false;
|
||||
|
||||
owner = firstResult.getOwner();
|
||||
startIndex = firstResult.getResultNumber();
|
||||
for (auto [index, value] : llvm::enumerate(values)) {
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (Value packedSlice = createPackedExtractSliceTensor(values, rewriter, loc))
|
||||
return packedSlice;
|
||||
|
||||
Operation* owner = nullptr;
|
||||
unsigned startIndex = 0;
|
||||
if (getContiguousOpResults(values, owner, startIndex))
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
|
||||
return createPackedExtractRowsSlice(
|
||||
extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(values.front().getType());
|
||||
if (!firstType || !firstType.hasStaticShape() || firstType.getRank() == 0)
|
||||
return {};
|
||||
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
|
||||
return {};
|
||||
|
||||
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
||||
&& lhs.resultType == rhs.resultType;
|
||||
@@ -89,45 +260,97 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
|
||||
auto* block = rewriter.createBlock(
|
||||
&mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(anchorChunk.input, block->getArgument(0));
|
||||
|
||||
for (Operation* op : anchorChunk.ops) {
|
||||
Operation* cloned = rewriter.clone(*op, mapping);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
|
||||
mapping.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
|
||||
}
|
||||
|
||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
SmallVector<Type> outputTypes;
|
||||
inputs.reserve(run.size());
|
||||
outputTypes.reserve(run.size());
|
||||
for (const RegularChunk& chunk : run) {
|
||||
for (const RegularChunk& chunk : run)
|
||||
inputs.push_back(chunk.input);
|
||||
outputTypes.push_back(chunk.output.getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||
auto mapOp =
|
||||
spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
|
||||
buildRegularMapBody(mapOp, anchorChunk, rewriter);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||
if (!packedInput)
|
||||
return;
|
||||
|
||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
|
||||
auto packedInit = tensor::EmptyOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Block* loopBlock = loop.getBody();
|
||||
rewriter.setInsertionPointToStart(loopBlock);
|
||||
Value iv = loopBlock->getArgument(0);
|
||||
Value acc = loopBlock->getArgument(1);
|
||||
|
||||
Value inputRowOffset = iv;
|
||||
if (inputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
|
||||
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets;
|
||||
SmallVector<OpFoldResult> extractSizes;
|
||||
SmallVector<OpFoldResult> extractStrides;
|
||||
extractOffsets.push_back(inputRowOffset);
|
||||
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(0)));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
extractOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
auto inputSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), inputType, packedInput, extractOffsets, extractSizes, extractStrides);
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(anchorChunk.input, inputSlice.getResult());
|
||||
for (Operation* op : anchorChunk.ops) {
|
||||
Operation* cloned = rewriter.clone(*op, mapping);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
|
||||
mapping.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
Value mappedOutput = mapping.lookup(anchorChunk.output);
|
||||
Value outputRowOffset = iv;
|
||||
if (outputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
|
||||
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets;
|
||||
SmallVector<OpFoldResult> insertSizes;
|
||||
SmallVector<OpFoldResult> insertStrides;
|
||||
insertOffsets.push_back(outputRowOffset);
|
||||
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(0)));
|
||||
insertStrides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
|
||||
insertOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
|
||||
insertStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
auto inserted = tensor::InsertSliceOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value replacement = extractPackedChunk(
|
||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(mapOp.getResult(index));
|
||||
output.replaceAllUsesWith(replacement);
|
||||
}
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
@@ -178,28 +401,29 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
channelIds.reserve(sortedEntries.size());
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
outputTypes.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
outputTypes.push_back(entry.op.getOutput().getType());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -255,17 +479,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -297,25 +524,25 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
outputTypes.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
outputTypes.push_back(op.getOutput().getType());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -352,17 +579,20 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -31,7 +31,8 @@ struct DenseSubviewKeyInfo {
|
||||
|
||||
static unsigned getHashValue(const DenseSubviewKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.source, llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
|
||||
llvm::hash_combine(key.source,
|
||||
llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
|
||||
llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end())));
|
||||
}
|
||||
|
||||
@@ -98,16 +99,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
alignment);
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr> foldDenseSubview(DenseElementsAttr denseAttr,
|
||||
ArrayRef<int64_t> staticOffsets,
|
||||
ArrayRef<int64_t> resultShape) {
|
||||
FailureOr<DenseElementsAttr>
|
||||
foldDenseSubview(DenseElementsAttr denseAttr, ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> resultShape) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast<int64_t>(staticOffsets.size())
|
||||
|| sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
|
||||
return failure();
|
||||
|
||||
static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo> cache;
|
||||
DenseSubviewKey key {denseAttr, SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
|
||||
DenseSubviewKey key {denseAttr,
|
||||
SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
|
||||
SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
|
||||
if (auto cached = cache.find(key); cached != cache.end())
|
||||
return cached->second;
|
||||
@@ -152,6 +153,30 @@ FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value)
|
||||
return denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value source, MemRefType resultType) {
|
||||
auto srcSubview = getStaticSubviewInfo(source);
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(source);
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
if (succeeded(srcSubview)) {
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
return foldDenseSubview(*denseAttr, *staticOffsets, resultType.getShape());
|
||||
}
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(resultType.getShape(), resultType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
return *denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
|
||||
@@ -36,6 +36,9 @@ llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAtt
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr>
|
||||
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
|
||||
@@ -90,6 +90,7 @@ static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||
return attr;
|
||||
}
|
||||
|
||||
// Folds constant linalg fills inside cores into private globals plus device copies.
|
||||
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -249,6 +250,7 @@ static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, M
|
||||
return DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
|
||||
// Folds transposes of constant globals so weight-only transposes stay host-side.
|
||||
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -304,11 +306,9 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
|
||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
});
|
||||
bool isAlwaysWeight = !transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(),
|
||||
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
@@ -330,6 +330,7 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
}
|
||||
};
|
||||
|
||||
// Collapses fill-and-copy allocation chains into one folded constant global.
|
||||
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -367,9 +368,8 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
|
||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) {
|
||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
});
|
||||
return llvm::all_of(castOp->getUsers(),
|
||||
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
|
||||
})) {
|
||||
allLiveUsersAreCoreOps = false;
|
||||
}
|
||||
@@ -417,6 +417,7 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts host copies from dense globals into direct folded globals.
|
||||
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -431,37 +432,14 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||
if (failed(maybeFoldedAttr))
|
||||
return failure();
|
||||
foldedAttr = *maybeFoldedAttr;
|
||||
}
|
||||
else {
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
@@ -477,7 +455,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
@@ -494,6 +472,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts PIM copies from dense globals into direct folded globals before codegen.
|
||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -511,37 +490,14 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||
if (failed(maybeFoldedAttr))
|
||||
return failure();
|
||||
foldedAttr = *maybeFoldedAttr;
|
||||
}
|
||||
else {
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
@@ -557,7 +513,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_memcp");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
@@ -577,13 +533,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||
patterns
|
||||
.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantHostCopyPattern,
|
||||
FoldConstantMemCpPattern>(
|
||||
patterns.getContext());
|
||||
patterns.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantHostCopyPattern,
|
||||
FoldConstantMemCpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -128,6 +128,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Splits core copies through subviews into contiguous copy chunks for codegen.
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -162,6 +163,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
||||
}
|
||||
};
|
||||
|
||||
// Splits host-to-device subview loads into contiguous copy chunks.
|
||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -193,6 +195,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
||||
}
|
||||
};
|
||||
|
||||
// Splits device-to-host subview stores into contiguous copy chunks.
|
||||
struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -224,6 +227,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
|
||||
}
|
||||
};
|
||||
|
||||
// Folds constant subviews used as core weights into standalone globals.
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -9,6 +10,8 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -21,6 +24,8 @@ namespace {
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
@@ -33,6 +38,91 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
||||
ops.push_back(op);
|
||||
});
|
||||
|
||||
for (Operation* op : ops) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op->emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end()) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
||||
|
||||
Value copiedValue;
|
||||
if constexpr (std::is_same_v<CoreOpTy, pim::PimCoreBatchOp>) {
|
||||
copiedValue = pim::PimMemCopyHostToDevBatchOp::create(
|
||||
rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||
.getOutput();
|
||||
}
|
||||
else {
|
||||
copiedValue = pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
cachedByType[originalType] = copiedValue;
|
||||
operand.set(copiedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
@@ -50,71 +140,11 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||
materializeHostConstantsInCore(coreOp, rewriter, hasFailure);
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op.getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op.emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end()) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(&op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
|
||||
|
||||
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op.getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||
|
||||
cachedByType[originalType] = hostToDevCopy.getResult();
|
||||
operand.set(hostToDevCopy.getResult());
|
||||
}
|
||||
}
|
||||
}
|
||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure);
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
|
||||
@@ -122,6 +122,14 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
op->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
hasFailure = true;
|
||||
});
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
@@ -33,8 +33,6 @@ def _parse_pim_pass_timings(output_text):
|
||||
pass_timings[label] = pass_timings.get(label, 0.0) + duration
|
||||
break
|
||||
|
||||
if not pass_timings:
|
||||
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
|
||||
return pass_timings
|
||||
|
||||
|
||||
@@ -43,7 +41,7 @@ def _format_command(cmd):
|
||||
|
||||
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
|
||||
crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None):
|
||||
# Define the arguments, with the possibility to set crossbar size and count
|
||||
args = [
|
||||
network_path,
|
||||
@@ -51,13 +49,13 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
output_base,
|
||||
"--maccel=PIM",
|
||||
"--EmitPimCodegen",
|
||||
# "--use-experimental-conv-impl=true",
|
||||
f"--crossbar-size={crossbar_size}",
|
||||
f"--crossbar-count={crossbar_count}",
|
||||
"--enable-timing",
|
||||
]
|
||||
if core_count is not None:
|
||||
args.append(f"--core-count={core_count}")
|
||||
if verbose:
|
||||
args.append("--enable-timing")
|
||||
|
||||
cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args]
|
||||
if reporter is not None:
|
||||
|
||||
@@ -47,7 +47,9 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
error_output = captured_output if not stream_output else recent_output
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
|
||||
exc = subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
|
||||
exc.output_already_streamed = stream_output and bool(captured_output)
|
||||
raise exc
|
||||
return bytes(captured_output)
|
||||
|
||||
|
||||
@@ -67,15 +69,15 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
||||
|
||||
stream_output = bool(getattr(reporter, "verbose", False))
|
||||
if not stream_output:
|
||||
process = subprocess.Popen(
|
||||
completed = subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False)
|
||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||
if completed.returncode != 0:
|
||||
raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
|
||||
return completed.stdout.decode("utf-8", errors="replace") if capture_output else None
|
||||
|
||||
try:
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
|
||||
+10
-7
@@ -27,7 +27,9 @@ def print_validation_error(reporter, rel, exc):
|
||||
file=sys.stderr, flush=True)
|
||||
if isinstance(exc, subprocess.CalledProcessError):
|
||||
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||
if exc.output:
|
||||
if getattr(exc, "output_already_streamed", False):
|
||||
print("Failure log already printed above.", file=sys.stderr, flush=True)
|
||||
elif exc.output:
|
||||
output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output)
|
||||
if output_text:
|
||||
print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True)
|
||||
@@ -160,12 +162,13 @@ def main():
|
||||
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
|
||||
print(f"| {rel.ljust(path_width)} | {status} |")
|
||||
print(separator)
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
total_timing_sum,
|
||||
timed_benchmark_count,
|
||||
)
|
||||
if a.verbose:
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
total_timing_sum,
|
||||
timed_benchmark_count,
|
||||
)
|
||||
|
||||
sys.exit(0 if n_passed == n_total else 1)
|
||||
|
||||
|
||||
+28
-23
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
@@ -11,7 +10,6 @@ from raptor import compile_with_raptor
|
||||
from gen_network_runner import gen_network_runner
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
STAGE_TITLES = (
|
||||
"Compile ONNX",
|
||||
"Build Runner",
|
||||
@@ -48,10 +46,12 @@ class ProgressReporter:
|
||||
self.verbose = verbose
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
self.suspended = False
|
||||
self.rendered_width = 0
|
||||
|
||||
def _clear(self):
|
||||
if self.enabled:
|
||||
sys.stdout.write("\033[2K\r")
|
||||
sys.stdout.write("\r" + (" " * self.rendered_width) + "\r")
|
||||
sys.stdout.flush()
|
||||
|
||||
def _render(self):
|
||||
if not self.enabled or self.suspended:
|
||||
@@ -70,16 +70,16 @@ class ProgressReporter:
|
||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||
|
||||
counts = (
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
model_counter = ""
|
||||
label = ""
|
||||
@@ -92,9 +92,12 @@ class ProgressReporter:
|
||||
|
||||
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
|
||||
label = label[:available_label_width]
|
||||
|
||||
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
|
||||
plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label
|
||||
rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL
|
||||
padded_width = max(self.rendered_width, len(plain_line))
|
||||
sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line))))
|
||||
sys.stdout.flush()
|
||||
self.rendered_width = len(plain_line)
|
||||
|
||||
def log(self, message="", color=None):
|
||||
if not self.verbose:
|
||||
@@ -124,18 +127,19 @@ class ProgressReporter:
|
||||
self._render()
|
||||
|
||||
def suspend(self):
|
||||
if self.enabled:
|
||||
self._clear()
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
|
||||
def resume(self):
|
||||
self.suspended = False
|
||||
self._render()
|
||||
|
||||
def finish(self):
|
||||
if self.enabled:
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
self.rendered_width = 0
|
||||
|
||||
|
||||
def run_command(cmd, cwd=None, reporter=None):
|
||||
@@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor):
|
||||
|
||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
||||
run_command(
|
||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
|
||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
|
||||
"--",
|
||||
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
||||
cwd=simulator_dir,
|
||||
reporter=reporter,
|
||||
@@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.advance()
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
|
||||
verbose=False)
|
||||
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
|
||||
print_info(reporter, f"Runner built at {runner_path}")
|
||||
reporter.advance()
|
||||
@@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
pim_pass_timings = compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count, core_count=core_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
|
||||
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
|
||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||
reporter.advance()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user