merge remote changes
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -3,6 +3,7 @@
|
|||||||
**/.vscode
|
**/.vscode
|
||||||
|
|
||||||
.claude
|
.claude
|
||||||
|
.codex
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
|
|
||||||
CMakeUserPresets.json
|
CMakeUserPresets.json
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|||||||
@@ -5,9 +5,11 @@
|
|||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@@ -55,9 +57,23 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
|||||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||||
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||||
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||||
|
SmallVector<mlir::Value> args;
|
||||||
|
|
||||||
|
|
||||||
|
for (mlir::Value arg : funcOp.getArguments()){
|
||||||
|
gatherMemEntry(arg);
|
||||||
|
args.push_back(arg);
|
||||||
|
}
|
||||||
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!hasWeightAlways(getGlobalOp)) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (globalMemrefOp.getName().starts_with("arg")){
|
||||||
|
StringRef indexStr = globalMemrefOp.getName().substr(4);
|
||||||
|
int index = 0;
|
||||||
|
llvm::to_integer(indexStr,index, 10);
|
||||||
|
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
|
||||||
|
}
|
||||||
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||||
if (inserted)
|
if (inserted)
|
||||||
gatherMemEntry(getGlobalOp.getResult());
|
gatherMemEntry(getGlobalOp.getResult());
|
||||||
@@ -66,8 +82,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (mlir::Value arg : funcOp.getArguments())
|
|
||||||
gatherMemEntry(arg);
|
|
||||||
|
|
||||||
funcOp.walk([&](memref::AllocOp allocOp) {
|
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||||
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
@@ -133,6 +147,12 @@ json::Object PimCodeGen::createEmptyOffset() {
|
|||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t PimCodeGen::remapCoreId(size_t coreId) const {
|
||||||
|
auto it = emittedCoreIds.find(coreId);
|
||||||
|
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
static json::Object createRs1OnlyOffset() {
|
static json::Object createRs1OnlyOffset() {
|
||||||
json::Object offset;
|
json::Object offset;
|
||||||
offset["offset_select"] = 1;
|
offset["offset_select"] = 1;
|
||||||
@@ -192,7 +212,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
|
|||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = opName;
|
json["op"] = opName;
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["core"] = coreId;
|
json["core"] = remapCoreId(coreId);
|
||||||
json["size"] = size;
|
json["size"] = size;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
@@ -414,6 +434,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
|||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
|
||||||
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
||||||
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
||||||
@@ -583,6 +606,29 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
|||||||
return scalarCore;
|
return scalarCore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void aliasMaterializedHostGlobals(
|
||||||
|
ModuleOp moduleOp, func::FuncOp funcOp, pim::PimCoreOp coreOp, PimAcceleratorMemory& memory) {
|
||||||
|
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!targetGlobal)
|
||||||
|
return;
|
||||||
|
|
||||||
|
mlir::Value aliasedValue;
|
||||||
|
funcOp.walk([&](memref::GetGlobalOp candidate) {
|
||||||
|
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
|
||||||
|
return;
|
||||||
|
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
|
||||||
|
aliasedValue = candidate.getResult();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (aliasedValue)
|
||||||
|
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||||
static OnnxMlirCompilerErrorCodes
|
static OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
@@ -677,6 +723,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
||||||
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
||||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
||||||
|
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||||
|
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
op.emitError("Unsupported codegen for this operation");
|
||||||
op.dump();
|
op.dump();
|
||||||
@@ -880,13 +928,14 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
||||||
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||||
PimAcceleratorMemory& memory,
|
PimAcceleratorMemory& memory,
|
||||||
size_t coreCount,
|
size_t maxCoreId,
|
||||||
json::Object xbarsPerArrayGroup,
|
json::Object xbarsPerArrayGroup,
|
||||||
StringRef outputDirPath) {
|
StringRef outputDirPath) {
|
||||||
json::Object configJson;
|
json::Object configJson;
|
||||||
|
|
||||||
// +1 because pimsim-nn also considers the host as a core
|
// pimsim-nn indexes cores directly by their numeric core ID, with the host
|
||||||
configJson["core_cnt"] = coreCount + 1;
|
// occupying core 0.
|
||||||
|
configJson["core_cnt"] = maxCoreId + 1;
|
||||||
|
|
||||||
// TODO: Should this be based on the floating point type used in the model?
|
// TODO: Should this be based on the floating point type used in the model?
|
||||||
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
|
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
|
||||||
@@ -960,12 +1009,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
// For each core, specify the number of crossbar per array group.
|
// For each core, specify the number of crossbar per array group.
|
||||||
// This implementation always assigns one crossbar per group.
|
// This implementation always assigns one crossbar per group.
|
||||||
json::Object xbarsPerArrayGroup;
|
json::Object xbarsPerArrayGroup;
|
||||||
size_t coreCount = 0;
|
size_t maxCoreId = 0;
|
||||||
|
|
||||||
// Create Weight Folder
|
// Create Weight Folder
|
||||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||||
|
|
||||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||||
|
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
||||||
|
size_t nextEmittedCoreId = 1;
|
||||||
|
|
||||||
|
for (Operation* op : coreLikeOps) {
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
||||||
|
if (!emittedCoreIds.contains(originalCoreId))
|
||||||
|
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||||
|
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||||
|
if (!emittedCoreIds.contains(originalCoreId))
|
||||||
|
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (Operation* op : coreLikeOps) {
|
for (Operation* op : coreLikeOps) {
|
||||||
SmallVector<pim::PimCoreOp> scalarCores;
|
SmallVector<pim::PimCoreOp> scalarCores;
|
||||||
@@ -979,8 +1047,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : scalarCores) {
|
for (pim::PimCoreOp coreOp : scalarCores) {
|
||||||
auto coreId = coreOp.getCoreId();
|
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
||||||
coreCount++;
|
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||||
|
maxCoreId = std::max(maxCoreId, coreId);
|
||||||
|
|
||||||
std::error_code errorCode;
|
std::error_code errorCode;
|
||||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||||
@@ -991,7 +1060,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
}
|
}
|
||||||
coreFileStream << '[';
|
coreFileStream << '[';
|
||||||
|
|
||||||
PimCodeGen coreCodeGen(memory, coreFileStream);
|
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||||
|
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||||
|
|
||||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||||
@@ -1009,7 +1079,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& mapWeightToFile = mapCoreWeightToFileName[static_cast<size_t>(coreId)];
|
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
|
||||||
json::Array xbarsPerGroup;
|
json::Array xbarsPerGroup;
|
||||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
if (index >= coreOp.getWeights().size()) {
|
if (index >= coreOp.getWeights().size()) {
|
||||||
@@ -1037,5 +1107,5 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
coreOp.erase();
|
coreOp.erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath);
|
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
@@ -58,10 +59,12 @@ public:
|
|||||||
class PimCodeGen {
|
class PimCodeGen {
|
||||||
PimAcceleratorMemory& memory;
|
PimAcceleratorMemory& memory;
|
||||||
llvm::raw_fd_ostream& coreFileStream;
|
llvm::raw_fd_ostream& coreFileStream;
|
||||||
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||||
|
|
||||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||||
return memory.getValueAddress(value, knowledge);
|
return memory.getValueAddress(value, knowledge);
|
||||||
}
|
}
|
||||||
|
size_t remapCoreId(size_t coreId) const;
|
||||||
|
|
||||||
static llvm::json::Object createEmptyOffset();
|
static llvm::json::Object createEmptyOffset();
|
||||||
void emitInstruction(llvm::json::Object instruction) const;
|
void emitInstruction(llvm::json::Object instruction) const;
|
||||||
@@ -83,8 +86,10 @@ class PimCodeGen {
|
|||||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
PimCodeGen(PimAcceleratorMemory& memory,
|
||||||
: memory(memory), coreFileStream(coreJson) {}
|
llvm::raw_fd_ostream& coreJson,
|
||||||
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
||||||
|
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
@@ -106,6 +111,7 @@ public:
|
|||||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
@@ -11,6 +12,7 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@@ -183,6 +185,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
encapsulateGlobalInstruction(*entryFunc);
|
||||||
|
|
||||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||||
@@ -199,19 +202,36 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
|
|||||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||||
Value source = funcSource(toRemoveOp);
|
Value source = funcSource(toRemoveOp);
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
IRMapping mapper;
|
||||||
IRMapping mapper;
|
mapper.map(source, BB->getArgument(0));
|
||||||
mapper.map(source, BB->getArgument(0));
|
auto newInst = rewriter.clone(*inst, mapper);
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
inst->erase();
|
||||||
inst->erase();
|
return true;
|
||||||
return true;
|
}
|
||||||
}
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||||
|
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
|
||||||
|
auto source = toRemoveOp.getSource();
|
||||||
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||||
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||||
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
|
IRMapping mapper;
|
||||||
|
mapper.map(source, BB->getArgument(0));
|
||||||
|
auto newInst = rewriter.clone(*inst, mapper);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||||
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
|
inst->erase();
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -245,6 +265,24 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
inst->erase();
|
inst->erase();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||||
|
SmallVector<Type> sourceTypes;
|
||||||
|
SmallVector<Location> sourceLoc;
|
||||||
|
for (auto source : sources) {
|
||||||
|
sourceTypes.push_back(source.getType());
|
||||||
|
sourceLoc.push_back(loc);
|
||||||
|
}
|
||||||
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||||
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||||
|
mapper.map(source, bbArg);
|
||||||
|
auto newConcat = rewriter.clone(*inst, mapper);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||||
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
|
inst->erase();
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -306,6 +344,89 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
|
|||||||
return cast<Value>(mapped);
|
return cast<Value>(mapped);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool sourceOpernadHasWeightAlways(Operation* op) {
|
||||||
|
if (op == nullptr)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
Operation* source = nullptr;
|
||||||
|
do {
|
||||||
|
|
||||||
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
|
||||||
|
auto tmpSource = extractSliceOp.getSource();
|
||||||
|
auto definingOp = tmpSource.getDefiningOp();
|
||||||
|
if (definingOp)
|
||||||
|
op = definingOp;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
|
||||||
|
auto tmpSource = extractRowsOp.getInput();
|
||||||
|
auto definingOp = tmpSource.getDefiningOp();
|
||||||
|
if (definingOp)
|
||||||
|
op = definingOp;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
|
||||||
|
auto tmpSource = expandShapeOp.getSrc();
|
||||||
|
auto definingOp = tmpSource.getDefiningOp();
|
||||||
|
if (definingOp)
|
||||||
|
op = definingOp;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
|
||||||
|
auto tmpSource = transposeOp.getData();
|
||||||
|
auto definingOp = tmpSource.getDefiningOp();
|
||||||
|
if (definingOp)
|
||||||
|
op = definingOp;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
|
||||||
|
auto tmpSource = collapseShapeOp.getSrc();
|
||||||
|
auto definingOp = tmpSource.getDefiningOp();
|
||||||
|
if (definingOp)
|
||||||
|
op = definingOp;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
|
||||||
|
source = constantOp;
|
||||||
|
}
|
||||||
|
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
|
||||||
|
bool res = false;
|
||||||
|
for (auto operand : concatOp.getOperands()) {
|
||||||
|
res |= hasWeightAlways(operand.getDefiningOp());
|
||||||
|
if (res)
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
|
||||||
|
bool res = false;
|
||||||
|
for (auto operand : concatOp.getOperands()) {
|
||||||
|
res |= hasWeightAlways(operand.getDefiningOp());
|
||||||
|
if (res)
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
op->dump();
|
||||||
|
llvm_unreachable("Global instruction not handle in func");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (source == nullptr);
|
||||||
|
|
||||||
|
if (hasWeightAlways(source))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO what we want to keep in global?
|
// TODO what we want to keep in global?
|
||||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
@@ -314,8 +435,14 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
while (keep) {
|
while (keep) {
|
||||||
keep = false;
|
keep = false;
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
||||||
|
instruction)
|
||||||
|
|| isa<func::ReturnOp>(instruction)
|
||||||
|
|| sourceOpernadHasWeightAlways(&instruction))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||||
|
|
||||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||||
|
|||||||
@@ -23,7 +23,10 @@ static Value extractSliceAt(
|
|||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||||
sizes[axis] = rewriter.getIndexAttr(size);
|
sizes[axis] = rewriter.getIndexAttr(size);
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
SmallVector<int64_t> resultShape(inputType.getShape());
|
||||||
|
resultShape[axis] = size;
|
||||||
|
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||||
@@ -49,12 +52,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
int64_t sliceSize = resultType.getShape()[axis];
|
int64_t sliceSize = resultType.getShape()[axis];
|
||||||
auto computeOp =
|
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||||
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
|
|
||||||
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
|
|
||||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
|
|
||||||
});
|
|
||||||
outputs.push_back(computeOp.getResult(0));
|
|
||||||
offset += sliceSize;
|
offset += sliceSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
|
|||||||
add_pim_library(OMSpatialToPim
|
add_pim_library(OMSpatialToPim
|
||||||
SpatialToPimPass.cpp
|
SpatialToPimPass.cpp
|
||||||
Common.cpp
|
Common.cpp
|
||||||
|
Patterns.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
385
src/PIM/Conversion/SpatialToPim/Patterns.cpp
Normal file
385
src/PIM/Conversion/SpatialToPim/Patterns.cpp
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
||||||
|
Location loc = extractSliceOp.getLoc();
|
||||||
|
|
||||||
|
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (auto& uses : extractSliceOp->getUses()) {
|
||||||
|
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||||
|
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
|
||||||
|
if (spatCompute.getInputs().empty())
|
||||||
|
return failure();
|
||||||
|
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
|
||||||
|
|
||||||
|
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||||
|
|
||||||
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||||
|
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
|
||||||
|
if (BBArgValue.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||||
|
auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
|
|
||||||
|
if (BBArgValue.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||||
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
{
|
||||||
|
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.eraseOp(extractSliceOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||||
|
static int i = 0;
|
||||||
|
Location loc = constantOp.getLoc();
|
||||||
|
|
||||||
|
if (hasWeightAlways(constantOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
||||||
|
if (isa<spatial::SpatCompute>(op))
|
||||||
|
return false;
|
||||||
|
if (isa<func::FuncOp>(op->getParentOp()))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||||
|
|
||||||
|
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||||
|
|
||||||
|
if (constRankedTensorType) {
|
||||||
|
mlir::MemRefType memRefType =
|
||||||
|
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||||
|
std::string argName = "const_" + std::to_string(i++);
|
||||||
|
memref::GlobalOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getStringAttr(argName),
|
||||||
|
rewriter.getStringAttr("private"),
|
||||||
|
TypeAttr::get(memRefType),
|
||||||
|
constantOp.getValueAttr(),
|
||||||
|
rewriter.getUnitAttr(),
|
||||||
|
{});
|
||||||
|
|
||||||
|
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||||
|
|
||||||
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||||
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
|
||||||
|
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||||
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||||
|
|
||||||
|
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||||
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
{
|
||||||
|
|
||||||
|
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||||
|
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||||
|
|
||||||
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||||
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
|
||||||
|
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
||||||
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||||
|
|
||||||
|
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
||||||
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||||
|
if (!mapSpatComputeToConst.contains(parent)) {
|
||||||
|
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||||
|
}
|
||||||
|
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||||
|
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||||
|
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
||||||
|
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
||||||
|
}
|
||||||
|
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto parent = constantOp->getParentOp();
|
||||||
|
rewriter.eraseOp(constantOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
|
||||||
|
|
||||||
|
if (funcOp.getArguments().empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (llvm::all_of(funcOp.getArguments(),
|
||||||
|
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
|
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
|
||||||
|
if (arg.getUses().empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(funcOp.getOperation());
|
||||||
|
|
||||||
|
assert(isa<mlir::RankedTensorType>(arg.getType()));
|
||||||
|
|
||||||
|
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
|
||||||
|
mlir::MemRefType memRefType =
|
||||||
|
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
|
||||||
|
|
||||||
|
std::string argName = "arg_" + std::to_string(index);
|
||||||
|
|
||||||
|
memref::GlobalOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getStringAttr(argName),
|
||||||
|
rewriter.getStringAttr("private"),
|
||||||
|
TypeAttr::get(memRefType),
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{});
|
||||||
|
|
||||||
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||||
|
auto argUser = argUses.getOwner();
|
||||||
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||||
|
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(toTensor);
|
||||||
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||||
|
auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(toTensor);
|
||||||
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
rewriter.setInsertionPoint(argUser);
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
rewriter.startOpModification(argUser);
|
||||||
|
argUses.set(toTensor);
|
||||||
|
rewriter.finalizeOpModification(argUser);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
||||||
|
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||||
|
patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
10
src/PIM/Conversion/SpatialToPim/Patterns.hpp
Normal file
10
src/PIM/Conversion/SpatialToPim/Patterns.hpp
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,20 +1,26 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/BuiltinDialect.h"
|
#include "mlir/IR/BuiltinDialect.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@@ -24,6 +30,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||||
|
#include "Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -53,7 +60,7 @@ struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>>
|
|||||||
void runOnOperation() final;
|
void runOnOperation() final;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallVector<Value> outputTensors;
|
SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
|
||||||
size_t coreId = 0;
|
size_t coreId = 0;
|
||||||
SmallVector<Operation*> operationsToRemove;
|
SmallVector<Operation*> operationsToRemove;
|
||||||
|
|
||||||
@@ -179,7 +186,22 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
||||||
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
Value input = extractRowsOp.getInput();
|
||||||
|
RankedTensorType inputType;
|
||||||
|
if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
|
||||||
|
inputType = tensorType;
|
||||||
|
}
|
||||||
|
else if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
|
||||||
|
inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
|
||||||
|
rewriter.setInsertionPoint(extractRowsOp);
|
||||||
|
input = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
|
||||||
|
return;
|
||||||
|
}
|
||||||
int64_t numCols = inputType.getDimSize(1);
|
int64_t numCols = inputType.getDimSize(1);
|
||||||
|
|
||||||
SmallVector<Value> replacements;
|
SmallVector<Value> replacements;
|
||||||
@@ -187,11 +209,16 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
|||||||
|
|
||||||
rewriter.setInsertionPoint(extractRowsOp);
|
rewriter.setInsertionPoint(extractRowsOp);
|
||||||
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
|
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
|
||||||
|
auto outputType = dyn_cast<RankedTensorType>(output.getType());
|
||||||
|
if (!outputType) {
|
||||||
|
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
|
||||||
|
return;
|
||||||
|
}
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto rowSlice = tensor::ExtractSliceOp::create(
|
auto rowSlice = tensor::ExtractSliceOp::create(
|
||||||
rewriter, extractRowsOp.getLoc(), cast<RankedTensorType>(output.getType()), extractRowsOp.getInput(), offsets, sizes, strides);
|
rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||||
replacements.push_back(rowSlice.getResult());
|
replacements.push_back(rowSlice.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,6 +232,75 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
|||||||
rewriter.replaceOp(concatOp, concatenated);
|
rewriter.replaceOp(concatOp, concatenated);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||||
|
SmallVectorImpl<Operation*>& helperChain,
|
||||||
|
bool requireReturnUse = true) {
|
||||||
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
if (requireReturnUse
|
||||||
|
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Block& block = computeOp.getBody().front();
|
||||||
|
if (block.getNumArguments() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Operation*> reverseChain;
|
||||||
|
Value currentValue = yieldOp.getOperands().front();
|
||||||
|
Value blockArg = block.getArgument(0);
|
||||||
|
|
||||||
|
while (currentValue != blockArg) {
|
||||||
|
Operation* definingOp = currentValue.getDefiningOp();
|
||||||
|
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
|
||||||
|
return failure();
|
||||||
|
reverseChain.push_back(definingOp);
|
||||||
|
currentValue = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||||
|
if (!chainSet.contains(&op)
|
||||||
|
&& !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||||
|
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||||
|
return false;
|
||||||
|
if (!llvm::all_of(computeOp.getResult(0).getUsers(),
|
||||||
|
[](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
Block& block = computeOp.getBody().front();
|
||||||
|
if (block.getNumArguments() != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(computeOp);
|
||||||
|
IRMapping mapping;
|
||||||
|
for (Operation& op : block.without_terminator()) {
|
||||||
|
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||||
|
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||||
|
computeOp.getResult(0).replaceAllUsesWith(replacement);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
struct ReturnUseInfo {
|
struct ReturnUseInfo {
|
||||||
size_t returnIndex;
|
size_t returnIndex;
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
@@ -295,6 +391,20 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
|
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
||||||
|
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
currentValue = helperCompute.getResult(0);
|
||||||
|
auto currentUses = currentValue.getUses();
|
||||||
|
if (rangeLength(currentUses) != 1)
|
||||||
|
return std::nullopt;
|
||||||
|
currentUser = currentUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
while (isChannelUseChainOp(currentUser)) {
|
while (isChannelUseChainOp(currentUser)) {
|
||||||
helperChain.push_back(currentUser);
|
helperChain.push_back(currentUser);
|
||||||
auto currentUses = currentUser->getResult(0).getUses();
|
auto currentUses = currentUser->getResult(0).getUses();
|
||||||
@@ -419,21 +529,22 @@ static void cloneHelperChain(Value sourceValue,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitHostCopy(IRRewriter& rewriter,
|
static Value emitHostCopy(IRRewriter& rewriter,
|
||||||
Location loc,
|
Location loc,
|
||||||
Value outputTensor,
|
Value outputTensor,
|
||||||
Value sourceValue,
|
Value sourceValue,
|
||||||
int32_t hostTargetOffset,
|
int32_t hostTargetOffset,
|
||||||
int32_t deviceSourceOffset,
|
int32_t deviceSourceOffset,
|
||||||
int32_t sizeInBytes) {
|
int32_t sizeInBytes) {
|
||||||
PimMemCopyDevToHostOp::create(rewriter,
|
return PimMemCopyDevToHostOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
sourceValue,
|
sourceValue,
|
||||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnOperation() {
|
void SpatialToPimPass::runOnOperation() {
|
||||||
@@ -458,12 +569,21 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
scf::SCFDialect,
|
scf::SCFDialect,
|
||||||
BuiltinDialect>();
|
BuiltinDialect>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
{
|
||||||
populateWithGenerated(patterns);
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateWithGenerated(patterns);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateGlobalTensorToMemrefPatterns(patterns);
|
||||||
|
|
||||||
|
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||||
}
|
}
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
|
|
||||||
@@ -489,7 +609,8 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||||
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { receiveOps.push_back(op); });
|
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||||
|
receiveOps.push_back(op);
|
||||||
for (auto receiveOp : receiveOps) {
|
for (auto receiveOp : receiveOps) {
|
||||||
bool onlyPendingRemovalUsers = llvm::all_of(
|
bool onlyPendingRemovalUsers = llvm::all_of(
|
||||||
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
|
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
|
||||||
@@ -505,22 +626,26 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
|
SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
|
||||||
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { receiveManyOps.push_back(op); });
|
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
|
||||||
|
receiveManyOps.push_back(op);
|
||||||
for (auto receiveManyOp : receiveManyOps)
|
for (auto receiveManyOp : receiveManyOps)
|
||||||
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
||||||
|
|
||||||
SmallVector<spatial::SpatChannelSendOp> sendOps;
|
SmallVector<spatial::SpatChannelSendOp> sendOps;
|
||||||
funcOp.walk([&](spatial::SpatChannelSendOp op) { sendOps.push_back(op); });
|
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
|
||||||
|
sendOps.push_back(op);
|
||||||
for (auto sendOp : sendOps)
|
for (auto sendOp : sendOps)
|
||||||
lowerChannelSend(sendOp, rewriter);
|
lowerChannelSend(sendOp, rewriter);
|
||||||
|
|
||||||
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
|
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
|
||||||
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { sendManyOps.push_back(op); });
|
for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
|
||||||
|
sendManyOps.push_back(op);
|
||||||
for (auto sendManyOp : sendManyOps)
|
for (auto sendManyOp : sendManyOps)
|
||||||
lowerChannelSendMany(sendManyOp, rewriter);
|
lowerChannelSendMany(sendManyOp, rewriter);
|
||||||
|
|
||||||
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
|
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
|
||||||
funcOp.walk([&](spatial::SpatExtractRowsOp op) { extractRowsOps.push_back(op); });
|
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
|
||||||
|
extractRowsOps.push_back(op);
|
||||||
for (auto extractRowsOp : extractRowsOps)
|
for (auto extractRowsOp : extractRowsOps)
|
||||||
lowerExtractRows(extractRowsOp, rewriter);
|
lowerExtractRows(extractRowsOp, rewriter);
|
||||||
|
|
||||||
@@ -560,6 +685,36 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
assert(false && "tracked op removal reached a cycle or missed dependency");
|
assert(false && "tracked op removal reached a cycle or missed dependency");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
|
||||||
|
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
|
||||||
|
for (auto concatOp : remainingConcatOps)
|
||||||
|
lowerConcat(concatOp, rewriter);
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
|
||||||
|
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
|
||||||
|
for (auto receiveOp : remainingReceiveOps)
|
||||||
|
lowerChannelReceive(receiveOp, rewriter);
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
|
||||||
|
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
|
||||||
|
for (auto receiveManyOp : remainingReceiveManyOps)
|
||||||
|
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
|
||||||
|
funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
|
||||||
|
for (auto sendOp : remainingSendOps)
|
||||||
|
lowerChannelSend(sendOp, rewriter);
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
|
||||||
|
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
|
||||||
|
for (auto sendManyOp : remainingSendManyOps)
|
||||||
|
lowerChannelSendMany(sendManyOp, rewriter);
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
|
||||||
|
funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
|
||||||
|
for (auto extractRowsOp : remainingExtractRowsOps)
|
||||||
|
lowerExtractRows(extractRowsOp, rewriter);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
bool hasSpatialOps = false;
|
bool hasSpatialOps = false;
|
||||||
moduleOp.walk([&](Operation* op) {
|
moduleOp.walk([&](Operation* op) {
|
||||||
@@ -579,6 +734,13 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
|
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
|
||||||
|
return;
|
||||||
|
|
||||||
|
SmallVector<Operation*> helperChain;
|
||||||
|
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
|
||||||
|
return;
|
||||||
|
|
||||||
auto& block = computeOp.getRegion().front();
|
auto& block = computeOp.getRegion().front();
|
||||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
|
||||||
@@ -616,9 +778,9 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
|
|
||||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||||
Value outputTensor = outputTensors[returnUse->returnIndex];
|
|
||||||
if (auto storedOp = storedValue.getDefiningOp())
|
if (auto storedOp = storedValue.getDefiningOp())
|
||||||
rewriter.setInsertionPointAfter(storedOp);
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
|
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -637,8 +799,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
if (isa<func::ReturnOp>(resultUser)) {
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
|
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -654,13 +816,13 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex];
|
|
||||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
|
||||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
if (concatReturnUse->helperChain.empty()) {
|
if (concatReturnUse->helperChain.empty()) {
|
||||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
|
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -671,7 +833,15 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto storedType = cast<RankedTensorType>(yieldValue.getType());
|
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||||
|
if (!storedType) {
|
||||||
|
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||||
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
|
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
|
||||||
@@ -701,19 +871,18 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
|
|
||||||
auto scalarTensorType =
|
auto scalarTensorType =
|
||||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
|
||||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||||
rewriter.setInsertionPointAfter(elementSlice);
|
rewriter.setInsertionPointAfter(elementSlice);
|
||||||
|
|
||||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||||
emitHostCopy(rewriter,
|
outputTensor = emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
elementSlice.getResult(),
|
elementSlice.getResult(),
|
||||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||||
0,
|
0,
|
||||||
static_cast<int32_t>(elementSize));
|
static_cast<int32_t>(elementSize));
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -848,6 +1017,26 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||||
|
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||||
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
auto clonedTensor = cloned->getResult(0);
|
||||||
|
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||||
|
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
clonedTensor,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(toTensorOp.getResult(), copied);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (Value operand : op.getOperands()) {
|
for (Value operand : op.getOperands()) {
|
||||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||||
continue;
|
continue;
|
||||||
@@ -922,17 +1111,33 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
|
|
||||||
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
outputTensors.reserve(returnOp->getNumOperands());
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||||
for (auto returnValue : returnOp->getOperands()) {
|
|
||||||
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
||||||
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||||
assert(!hasWeightAlways(returnValueDefiningOp));
|
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||||
outputTensors.push_back(returnValue);
|
outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto newOutputTensor =
|
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
|
||||||
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
|
auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
|
||||||
outputTensors.push_back(newOutputTensor);
|
|
||||||
|
std::string outputName = "output_" + std::to_string(index);
|
||||||
|
rewriter.setInsertionPoint(returnOp.getParentOp());
|
||||||
|
memref::GlobalOp::create(rewriter,
|
||||||
|
returnOp.getLoc(),
|
||||||
|
rewriter.getStringAttr(outputName),
|
||||||
|
rewriter.getStringAttr("private"),
|
||||||
|
TypeAttr::get(memRefType),
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{});
|
||||||
|
outputTensors.push_back(
|
||||||
|
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
return toTensor.getResult();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -940,11 +1145,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
|||||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||||
auto tensorType = cast<ShapedType>(valueToReplace.getType());
|
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||||
Type elementType = tensorType.getElementType();
|
Type elementType = tensorType.getElementType();
|
||||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||||
|
|
||||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||||
|
|
||||||
@@ -953,86 +1158,27 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
loc,
|
loc,
|
||||||
tensorType,
|
tensorType,
|
||||||
deviceTensor,
|
deviceTensor,
|
||||||
hostTensor,
|
inputTensor,
|
||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||||
|
|
||||||
rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
|
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||||
};
|
};
|
||||||
|
|
||||||
// Replace input tensors with memRefs
|
|
||||||
SmallVector<bufferization::ToTensorOp, 8> inputTensors;
|
|
||||||
for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
|
|
||||||
BlockArgument tensorArg = funcOp.getArgument(i);
|
|
||||||
DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
|
|
||||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
|
||||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
|
||||||
|
|
||||||
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
|
||||||
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
|
||||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
|
||||||
|
|
||||||
Block& block = funcOp.getBody().front();
|
|
||||||
rewriter.setInsertionPoint(&block.front());
|
|
||||||
auto toTensorOp =
|
|
||||||
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
|
||||||
inputTensors.push_back(toTensorOp);
|
|
||||||
|
|
||||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
|
||||||
if (failed(funcOp.eraseArgument(i)))
|
|
||||||
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
|
||||||
for (auto& op : funcOp.getBody().getOps())
|
for (auto& op : funcOp.getBody().getOps())
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
unsigned numComputeWeights = computeOp.getWeights().size();
|
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
||||||
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
|
continue;
|
||||||
TypedValue<TensorType> tensorSource;
|
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
||||||
int64_t elementsOffset = 0;
|
if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
|
||||||
|
assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
|
||||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
|
||||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
insertMemCopyHostToDev(toTensorOpValue, 0);
|
||||||
|
|
||||||
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
|
||||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
|
||||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
|
||||||
ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
|
|
||||||
assert("Extracting slice non-contiguous in memory"
|
|
||||||
&& isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));
|
|
||||||
|
|
||||||
for (size_t i = 0; i < sliceOffsets.size(); i++) {
|
|
||||||
int64_t partialOffset = sliceOffsets[i];
|
|
||||||
if (partialOffset != 0)
|
|
||||||
for (size_t j = i + 1; j < sourceShape.size(); j++)
|
|
||||||
partialOffset *= sourceShape[j];
|
|
||||||
elementsOffset += partialOffset;
|
|
||||||
}
|
|
||||||
|
|
||||||
computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
|
|
||||||
sliceOpsToRemove.insert(sliceOp);
|
|
||||||
}
|
}
|
||||||
else
|
|
||||||
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
|
|
||||||
|
|
||||||
// Values already produced inside the device-side graph must not be
|
|
||||||
// copied back through a host-to-device staging step here.
|
|
||||||
if (isa<spatial::SpatCompute, spatial::SpatChannelReceiveOp>(tensorSource.getDefiningOp()))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
|
|
||||||
insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto sliceOp : sliceOpsToRemove)
|
|
||||||
if (sliceOp->getUses().empty())
|
|
||||||
rewriter.eraseOp(sliceOp);
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1050,7 +1196,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||||
Operation* onlyUser = *op->getUsers().begin();
|
Operation* onlyUser = *op->getUsers().begin();
|
||||||
isExclusivelyOwnedByReturnChain =
|
isExclusivelyOwnedByReturnChain =
|
||||||
isa<func::ReturnOp, tensor::ConcatOp>(onlyUser) || isChannelUseChainOp(onlyUser);
|
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
|
||||||
}
|
}
|
||||||
if (!isExclusivelyOwnedByReturnChain)
|
if (!isExclusivelyOwnedByReturnChain)
|
||||||
return;
|
return;
|
||||||
@@ -1062,6 +1208,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
|
markOpToRemove(computeOp);
|
||||||
|
for (Value input : computeOp.getInputs())
|
||||||
|
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
markOpToRemove(concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getOperands())
|
for (Value operand : concatOp.getOperands())
|
||||||
@@ -1070,12 +1223,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
|
auto loc = returnOp.getLoc();
|
||||||
for (auto it : llvm::enumerate(originalOperands)) {
|
for (auto it : llvm::enumerate(originalOperands)) {
|
||||||
size_t orderWithinReturn = it.index();
|
size_t orderWithinReturn = it.index();
|
||||||
Operation* returnOperand = it.value().getDefiningOp();
|
Operation* returnOperand = it.value().getDefiningOp();
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
rewriter.modifyOpInPlace(returnOp,
|
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||||
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def PimTensor :
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||||
let summary = "Execute a block on a PIM core";
|
let summary = "Execute a block on a PIM core";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|||||||
@@ -3,12 +3,17 @@
|
|||||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Threading.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
@@ -40,14 +45,44 @@ private:
|
|||||||
|
|
||||||
void PimBufferizationPass::runOnOperation() {
|
void PimBufferizationPass::runOnOperation() {
|
||||||
auto moduleOp = getOperation();
|
auto moduleOp = getOperation();
|
||||||
|
// Refactor this into a function
|
||||||
|
{
|
||||||
|
auto funcOp = getPimEntryFunc(moduleOp);
|
||||||
|
|
||||||
// One-Shot-Bufferization
|
auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
|
||||||
bufferization::OneShotBufferizationOptions options;
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
options.allowUnknownOps = true;
|
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
|
||||||
bufferization::BufferizationState state;
|
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
|
||||||
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
// Again, allocate state LOCALLY per thread/function
|
||||||
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
bufferization::OneShotBufferizationOptions options;
|
||||||
signalPassFailure();
|
options.allowUnknownOps = true;
|
||||||
|
bufferization::BufferizationState state;
|
||||||
|
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
|
||||||
|
coreOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (failed(result)) {
|
||||||
|
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
|
||||||
|
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
|
||||||
|
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
|
||||||
|
});
|
||||||
|
|
||||||
|
// One-Shot-Bufferization
|
||||||
|
bufferization::OneShotBufferizationOptions options;
|
||||||
|
options.allowUnknownOps = true;
|
||||||
|
bufferization::BufferizationState state;
|
||||||
|
|
||||||
|
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
||||||
|
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
@@ -57,7 +92,18 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
populateWithGenerated(patterns);
|
populateWithGenerated(patterns);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
// Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
|
||||||
|
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
|
||||||
|
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||||
|
bool hasFailed = false;
|
||||||
|
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
|
||||||
|
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||||
|
return WalkResult::advance();
|
||||||
|
if (failed(applyPartialConversion(op, target, frozenPatterns)))
|
||||||
|
hasFailed = true;
|
||||||
|
return WalkResult::skip();
|
||||||
|
});
|
||||||
|
if (hasFailed) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,10 +116,9 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
||||||
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
rewriter.setInsertionPoint(coreOp);
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(mapOp);
|
rewriter.setInsertionPoint(mapOp);
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
mapOp.getLoc(),
|
mapOp.getLoc(),
|
||||||
@@ -258,9 +257,18 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
// Look through an optional pim.memcp_hd to find the source get_global.
|
||||||
|
// This occurs when the constant was staged into device memory before transposing.
|
||||||
|
pim::PimMemCopyHostToDevOp memcpHd;
|
||||||
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!sourceGetGlobal)
|
if (!sourceGetGlobal) {
|
||||||
return failure();
|
memcpHd = transposeOp.getInput().getDefiningOp<pim::PimMemCopyHostToDevOp>();
|
||||||
|
if (!memcpHd)
|
||||||
|
return failure();
|
||||||
|
sourceGetGlobal = memcpHd.getHostSource().getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!sourceGetGlobal)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||||
if (!moduleOp)
|
if (!moduleOp)
|
||||||
@@ -298,13 +306,26 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
|
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
!transposeOp->getUsers().empty()
|
!transposeOp->getUsers().empty()
|
||||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
|
||||||
|
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||||
|
});
|
||||||
if (isAlwaysWeight) {
|
if (isAlwaysWeight) {
|
||||||
markWeightAlways(newGlobal);
|
markWeightAlways(newGlobal);
|
||||||
markWeightAlways(newGetGlobal);
|
markWeightAlways(newGetGlobal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp<memref::AllocOp>();
|
||||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||||
|
|
||||||
|
if (memcpHd && memcpHd.use_empty()) {
|
||||||
|
auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp<memref::AllocOp>();
|
||||||
|
rewriter.eraseOp(memcpHd);
|
||||||
|
if (deviceAllocOp && deviceAllocOp->use_empty())
|
||||||
|
rewriter.eraseOp(deviceAllocOp);
|
||||||
|
}
|
||||||
|
if (outputAllocOp && outputAllocOp->use_empty())
|
||||||
|
rewriter.eraseOp(outputAllocOp);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -341,18 +362,25 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isa<pim::PimCoreOp>(user))
|
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
return llvm::all_of(castOp->getUsers(), [](Operation* user) {
|
||||||
|
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||||
|
});
|
||||||
})) {
|
})) {
|
||||||
allLiveUsersAreCoreOps = false;
|
allLiveUsersAreCoreOps = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
||||||
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
|
return isa<linalg::MapOp,
|
||||||
|
memref::SubViewOp,
|
||||||
|
memref::DeallocOp,
|
||||||
|
memref::CastOp,
|
||||||
|
pim::PimCoreOp,
|
||||||
|
pim::PimCoreBatchOp>(user);
|
||||||
})) {
|
})) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -389,6 +417,83 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
|
||||||
|
if (!allocOp)
|
||||||
|
return failure();
|
||||||
|
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||||
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
||||||
|
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
||||||
|
|
||||||
|
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||||
|
if (failed(denseAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
DenseElementsAttr foldedAttr;
|
||||||
|
if (succeeded(srcSubview)) {
|
||||||
|
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||||
|
if (failed(staticOffsets))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||||
|
if (failed(maybeFoldedAttr))
|
||||||
|
return failure();
|
||||||
|
foldedAttr = *maybeFoldedAttr;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
|
if (resultTensorType != denseAttr->getType())
|
||||||
|
return failure();
|
||||||
|
foldedAttr = *denseAttr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool allLiveUsersAreCores = true;
|
||||||
|
for (Operation* user : allocOp->getUsers()) {
|
||||||
|
if (user == copyOp)
|
||||||
|
continue;
|
||||||
|
if (isa<memref::DeallocOp>(user))
|
||||||
|
continue;
|
||||||
|
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||||
|
continue;
|
||||||
|
if (isa<memref::SubViewOp>(user)) {
|
||||||
|
allLiveUsersAreCores = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
|
||||||
|
if (allLiveUsersAreCores)
|
||||||
|
markWeightAlways(newGlobal);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(allocOp);
|
||||||
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||||
|
if (allLiveUsersAreCores)
|
||||||
|
markWeightAlways(newGetGlobal);
|
||||||
|
|
||||||
|
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
||||||
|
rewriter.eraseOp(copyOp);
|
||||||
|
if (allocOp.use_empty())
|
||||||
|
rewriter.eraseOp(allocOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -443,7 +548,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
continue;
|
continue;
|
||||||
if (isa<memref::DeallocOp>(user))
|
if (isa<memref::DeallocOp>(user))
|
||||||
continue;
|
continue;
|
||||||
if (isa<pim::PimCoreOp>(user))
|
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||||
continue;
|
continue;
|
||||||
if (isa<memref::SubViewOp>(user)) {
|
if (isa<memref::SubViewOp>(user)) {
|
||||||
allLiveUsersAreCores = false;
|
allLiveUsersAreCores = false;
|
||||||
@@ -473,7 +578,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
|
|
||||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||||
patterns
|
patterns
|
||||||
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>(
|
.add<FoldConstantTransposePattern,
|
||||||
|
FoldConstantAllocPattern,
|
||||||
|
FoldConstantCoreMapPattern,
|
||||||
|
FoldConstantHostCopyPattern,
|
||||||
|
FoldConstantMemCpPattern>(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,27 @@ static bool isAddressOnlyHostOp(Operation* op) {
|
|||||||
memref::SubViewOp,
|
memref::SubViewOp,
|
||||||
memref::CastOp,
|
memref::CastOp,
|
||||||
memref::CollapseShapeOp,
|
memref::CollapseShapeOp,
|
||||||
memref::ExpandShapeOp>(op);
|
memref::ExpandShapeOp,
|
||||||
|
memref::CopyOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity.
|
||||||
|
// Used for memref.copy operands which may be non-contiguous subviews.
|
||||||
|
static bool isBaseAddressableValue(Value value) {
|
||||||
|
while (true) {
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return true;
|
||||||
|
Operation* defOp = value.getDefiningOp();
|
||||||
|
if (!defOp)
|
||||||
|
return false;
|
||||||
|
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
|
||||||
|
return true;
|
||||||
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
|
||||||
|
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
|
||||||
|
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) { value = collapse.getSrc(); continue; }
|
||||||
|
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) { value = expand.getSrc(); continue; }
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isCodegenAddressableValue(Value value) {
|
static bool isCodegenAddressableValue(Value value) {
|
||||||
@@ -183,6 +203,13 @@ private:
|
|||||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||||
|
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||||
|
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
||||||
|
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class ValidationResult:
|
|||||||
|
|
||||||
|
|
||||||
class ProgressReporter:
|
class ProgressReporter:
|
||||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None):
|
||||||
self.total_models = total_models
|
self.total_models = total_models
|
||||||
self.stages_per_model = stages_per_model
|
self.stages_per_model = stages_per_model
|
||||||
self.total_steps = max(1, total_models * stages_per_model)
|
self.total_steps = max(1, total_models * stages_per_model)
|
||||||
@@ -45,7 +45,7 @@ class ProgressReporter:
|
|||||||
self.passed_models = 0
|
self.passed_models = 0
|
||||||
self.failed_models = 0
|
self.failed_models = 0
|
||||||
self.current_label = ""
|
self.current_label = ""
|
||||||
self.enabled = True
|
self.enabled = sys.stdout.isatty() if enabled is None else enabled
|
||||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||||
self.suspended = False
|
self.suspended = False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user