add constant folding and verification pass for pim host operations

better validation scripts output
big refactors
This commit is contained in:
NiccoloN
2026-03-20 12:08:12 +01:00
parent 4e50e056e3
commit 6e1de865bb
64 changed files with 1364 additions and 2265 deletions

View File

@@ -20,6 +20,8 @@ add_onnx_mlir_library(OMPIMAccel
Pass/CountInstructionPass.cpp Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp Pass/MessagePass.cpp
Pass/PimFoldHostConstantsPass.cpp
Pass/PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -43,4 +45,5 @@ add_onnx_mlir_library(OMPIMAccel
OMSpatialToGraphviz OMSpatialToGraphviz
OMSpatialToPIM OMSpatialToPIM
OMPIMCommon OMPIMCommon
MLIRTensorInferTypeOpInterfaceImpl
) )

View File

@@ -5,6 +5,7 @@
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
@@ -30,6 +31,60 @@ void dumpModule(ModuleOp moduleOp, const std::string& name) {
file.close(); file.close();
} }
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ")
<< entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
}
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>(); auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();

View File

@@ -1,6 +1,8 @@
#pragma once #pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
@@ -9,6 +11,7 @@
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate"; const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
inline constexpr llvm::StringRef PIM_WEIGHT_ALWAYS_ATTR_NAME = "weightAlways";
namespace onnx_mlir { namespace onnx_mlir {
@@ -18,6 +21,14 @@ void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name); void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*> llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter); getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);

View File

@@ -13,11 +13,12 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" #include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
@@ -49,8 +50,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others // Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants; SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!getGlobalOp->hasAttr("weightAlways")) { if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
auto iter = globalConstants.find(globalMemrefOp); auto iter = globalConstants.find(globalMemrefOp);
if (iter == globalConstants.end()) if (iter == globalConstants.end())
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
@@ -81,7 +82,7 @@ MemEntry PimMemory::getMemEntry(mlir::Value value) const {
return iter->second; return iter->second;
} }
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second; return deviceMem.try_emplace(id, memEntriesMap).first->second;
} }
@@ -112,10 +113,33 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
} }
value = source; value = source;
} }
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
}
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
}
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
}
else else
break; break;
} }
return memEntriesMap.at(value).address + offset;
auto iter = memEntriesMap.find(value);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
value.print(errs());
errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Missing mem entry");
}
return iter->second.address + offset;
} }
json::Object PimCodeGen::createEmptyOffset() { json::Object PimCodeGen::createEmptyOffset() {
@@ -348,6 +372,55 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
} }
} }
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
auto srcAddr = memory.getValueAddress(transposeOp.getData());
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
auto srcShape = srcType.getShape();
size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
size_t totalElements = srcType.getNumElements();
// Read permutation and compute its inverse
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
SmallVector<int64_t> permInv(rank);
for (size_t i = 0; i < rank; i++)
permInv[perm[i]] = i;
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);
for (size_t i = 0; i < rank; i++)
dstShape[i] = srcShape[perm[i]];
// Row-major strides for source and destination
SmallVector<size_t> srcStrides(rank, 1);
SmallVector<size_t> dstStrides(rank, 1);
for (int64_t i = rank - 2; i >= 0; i--) {
srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
}
// Emit element-by-element copy with transposed addressing
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
// Decompose flat source index into multi-dimensional index
SmallVector<size_t> srcIdx(rank);
size_t remaining = srcFlat;
for (size_t d = 0; d < rank; d++) {
srcIdx[d] = remaining / srcStrides[d];
remaining %= srcStrides[d];
}
// Compute flat destination index: dstIdx[d] = srcIdx[permInv[d]]
size_t dstFlat = 0;
for (size_t d = 0; d < rank; d++)
dstFlat += srcIdx[permInv[d]] * dstStrides[d];
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
}
}
size_t getMatrixSize(ShapedType matrixShape) { size_t getMatrixSize(ShapedType matrixShape) {
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4) if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
assert(false && "Unsupported matrix shape"); assert(false && "Unsupported matrix shape");
@@ -378,9 +451,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (getGlobalOp->hasAttr("weightAlways")) if (hasWeightAlways(getGlobalOp))
return; return;
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) if (!globalOp)
return; return;
auto initialValue = globalOp.getInitialValue(); auto initialValue = globalOp.getInitialValue();
@@ -416,7 +489,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
size_t processedOperations = 0; size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) { for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op)) if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
continue; continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
@@ -435,6 +508,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false); coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op)) else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op)) else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp); coreCodeGen.codeGenVAddOp(vaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op)) else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
@@ -475,7 +550,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
continue; continue;
} }
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr()); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) { if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex)); coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++; weightIndex++;
@@ -589,9 +664,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
} }
} }
auto funcOps = moduleOp.getOps<func::FuncOp>(); auto entryFunc = getPimEntryFunc(moduleOp);
assert(!funcOps.empty() && "No function found in the module"); if (failed(entryFunc))
auto funcOp = *funcOps.begin(); return CompilerFailure;
auto funcOp = *entryFunc;
PimAcceleratorMemory memory; PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp); memory.hostMem.allocateHost(moduleOp, funcOp);

View File

@@ -5,7 +5,7 @@
#include "Common/ValueMap.hpp" #include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -49,7 +49,7 @@ public:
PimAcceleratorMemory() PimAcceleratorMemory()
: hostMem(memEntriesMap) {} : hostMem(memEntriesMap) {}
PimMemory getOrCreateDeviceMem(size_t id); PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value) const; size_t getValueAddress(mlir::Value value) const;
}; };
@@ -95,6 +95,7 @@ public:
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const; void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
}; };
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);

View File

@@ -25,7 +25,6 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen; extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl; extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> exportCrossbarWeights;
extern llvm::cl::opt<size_t> crossbarSize; extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore; extern llvm::cl::opt<size_t> crossbarCountInCore;

View File

@@ -2,7 +2,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
@@ -46,6 +46,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimFoldHostConstantsPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimHostVerificationPass());
pm.addPass(createMessagePass("Pim host verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted")); pm.addPass(createMessagePass("Pim json code emitted"));

View File

@@ -3,21 +3,15 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen) add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.hpp
Math/Gemm.cpp Math/Gemm.cpp
Math/Conv.hpp
Math/Conv.cpp Math/Conv.cpp
Math/ExperimentalConv.cpp
Math/ExperimentalGemm.cpp
NN/Pooling.cpp NN/Pooling.cpp
NN/ExperimentalPooling.cpp
NN/ReduceMean.cpp NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp Tensor/ONNXConcatToTensorConcat.cpp
Tensor/RemoveUnusedHelperOps.cpp Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp Utils/AnnotateReplication.cpp
ONNXToSpatialPass.hpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp ONNXToSpatialCommon.cpp

View File

@@ -242,6 +242,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return success(); return success();
} }
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); } void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -18,6 +18,6 @@ struct ConvToGemm : mlir::OpConversionPattern<mlir::ONNXConvOp> {
mlir::ConversionPatternRewriter& rewriter) const override; mlir::ConversionPatternRewriter& rewriter) const override;
}; };
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,583 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
// NOTE:
// This might be useful to re-implement this considering for loops.
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
/**
* @brief A momentary representation of a core, to be used within the tiling of
* a convolution operation.
*/
class Core {
public:
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
: coreId(coreId), rewriter(rewriter) {}
/**
* @brief Add a MVM operation to the core.
*
* @param inputTile The input tile to the MVM operation.
* @param xbarIndex The index of the crossbar weight to use.
* @param outputTileId The id of the output tile.
* @param mvmOutType The result's shape.
* @return Value The result of the MVM operation.
*/
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
// Use the inputTile as the reference location for the MVM operation.
Location loc = inputTile.getLoc();
// Move the insertion point to the end of the block.
rewriter.setInsertionPointToEnd(block.get());
// Add the inputTile to the block arguments, and to the operands.
Value operand = operandMap.lookupOrNull(inputTile);
if (not operand) {
operand = block->addArgument(inputTile.getType(), loc);
operands.push_back(inputTile);
operandMap.map(inputTile, operand);
}
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
// is correct.
// Construct the MVM operation
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
// Since we are within the same core and no computation can happen in
// paralllel, we can just apply a linear reduction in case we have multiple
// MVM operations for the same outputTile.
auto lastMVM = outputTileToMVM.find(outputTileId);
// If an entry for this outputTile already exists, apply reduction.
if (lastMVM != outputTileToMVM.end()) {
// MVM results should have the same type for reduction.
assert(lastMVM->second.getType() == result.getType());
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
}
outputTileToMVM[outputTileId] = result;
return result;
}
/**
* @brief Mark a result as remappable, and return a shared pointer to it.
*
* This function marks a result as remappable, and returns a shared pointer to
* it. We need to keep track of these values to generate the YieldOp at a
* later stage.
*
* @param result A result to track, for later remapping.
* @return shared_ptr<Value> A shared pointer to the result.
*/
shared_ptr<Value> makeResultRemappable(Value result) {
// Verify that the result is present in the block.
assert(result.getDefiningOp()->getBlock() == block.get());
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
resultsToRemap.push_back(remappableResult);
results.push_back(result);
return remappableResult;
}
/**
* @brief Add a remappable operand to the core, to merge partial results
* inter-core.
*
* @param remappableOperand The operand to add.
* @return Value The block argument representing the operand.
*/
Value addRemappableOperand(std::shared_ptr<Value> operand) {
// Check that the operand is not already there.
assert(not operandMap.contains(*operand));
Value argument = block->addArgument(operand->getType(), operand->getLoc());
remappableOperands.push_back(operand);
return argument;
}
/**
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
*
* @param loc The location of the operation.
* @return spatial::SpatWeightedCompute
*/
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
// Get the shape of the results.
SmallVector<Type> resultTypes;
for (const auto& value : results)
resultTypes.push_back(value.getType());
// Create the WComputeOp, with non-remappable operands only.
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
// Add the body to the WComputeOp.
Block* releasedBlock = block.release();
wcomputeOp.getBody().push_back(releasedBlock);
// Add the `yieldOp` at the end, with the results.
rewriter.setInsertionPointToEnd(releasedBlock);
rewriter.create<spatial::SpatYieldOp>(loc, results);
return wcomputeOp;
}
/**
* @brief Remap the results to the WComputeOp results.
*/
void remapResults() {
// Remap all the results to the WComputeOp results.
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
for (size_t i = 0; i < resultsToRemap.size(); i++)
*resultsToRemap[i] = wcomputeOp.getResult(i);
}
void addRemappedOperands() {
// Insert the remappableOperands (which were remapped in
// `addRemappableOperand` of another Core)
for (auto remappedValue : remappableOperands)
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
// Update the wcomputeOp operandSegmentSize
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
}
size_t addXbarWeight(Value weight) {
assert(!isXbarsFull());
xbarWeights.push_back(weight);
return xbarWeights.size() - 1;
}
bool isXbarsFull() {
assert(xbarWeights.size() <= crossbarCountInCore);
return xbarWeights.size() == crossbarCountInCore;
}
bool isCoreEmpty() { return block->empty(); }
void dump() {
// Print the coreId
llvm::outs() << "Core " << coreId << ":\n";
// Print the weights
llvm::outs() << "Xbar Weights:\n";
for (auto weight : xbarWeights)
weight.dump();
// Print the operands
llvm::outs() << "Operands:\n";
for (auto operand : operands)
llvm::outs() << operand << "\n";
// Dump the body block
for (auto& op : block->getOperations())
op.dump();
// Print the results
llvm::outs() << "Results:\n";
for (auto result : results)
llvm::outs() << result << "\n";
}
const size_t coreId;
private:
ConversionPatternRewriter& rewriter;
// Should these be set<Value> instead? But I need to keep the order
vector<Value> operands;
vector<std::shared_ptr<Value>> remappableOperands;
vector<Value> results;
vector<std::shared_ptr<Value>> resultsToRemap;
// Maps from input tiles to the block operand
IRMapping operandMap;
// Map from outputTileId to MVM operation producing it
unordered_map<size_t, Value> outputTileToMVM;
vector<Value> xbarWeights;
unique_ptr<mlir::Block> block = make_unique<Block>();
spatial::SpatWeightedCompute wcomputeOp;
};
struct ConvToManyGemms : public OpConversionPattern<ONNXConvOp> {
ConvToManyGemms(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
struct Producer_t {
Value value;
shared_ptr<Core> core;
};
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
// TODO: Pad value at beginning and end of each dimension could be
// different. We should handle this case.
// MapOperations mapOperation = MapOperations::None;
//
// // If we have just one user, and it is an activation funcion (or more in
// // general a mapping operation) just inline it in the computeOps
// auto firstUserOp = *conv->getUsers().begin();
// if (conv->hasOneUse()) {
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
//
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
// return rewriter.notifyMatchFailure(
// conv, "Softmax not supported as activation for convolutions.");
// }
// }
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
Location loc = conv.getLoc();
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
// Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
// Tile the weight tensor
// Weight tiles need to be indexed by:
// a. Filter Tile
// b. Channel Tile
// c. Kernel `x` position
// d. Kernel `y` position
// For example: weightTiles[filterTile][channelTile][x][y]
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
sizes = {rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
for (size_t i = 0; i < outputTileCount; i++) {
if (i == outputTileCount - 1 && outputTileRemainder != 0)
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
sizes[1] = rewriter.getIndexAttr(crossbarSize);
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
for (size_t j = 0; j < inputTileCount; j++) {
if (j == inputTileCount - 1 && inputTileRemainder != 0)
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
for (size_t x = 0; x < krn_w; x++) {
for (size_t y = 0; y < krn_h; y++) {
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
weightTiles[i][j][x][y] =
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
}
}
}
}
/* Distribute the computation among many compute cores
* Try to compute in-core the computation for each output tile, and reduce
* over as few cores as possible
*/
// Tile the output tensor
// Output tiles need to be indexed by:
// a. Filter Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[filterTile][x][y]
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
size_t replicationFactor;
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
replicationFactor = 1;
else
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
// producers[outTile][out_x][out_y][producerIndex]
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
outputTileCount,
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
// Schedule in cores
size_t coreId = 0;
vector<shared_ptr<Core>> curCores(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++)
curCores[i] = make_shared<Core>(coreId++, rewriter);
vector<shared_ptr<Core>> cores;
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
RankedTensorType mvmOutType =
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
vector<size_t> xbarIndexes(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++)
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
size_t out_x = 0;
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
size_t out_y = 0;
// I use `replicationFactor` cores. I divide the input_w into
// `replicationFactor` slices, and each slice is distributed to a
// core. `coreIndex` is the index of the core that will be used
// for this slice
size_t coreIndex = in_x / replicationSliceSize;
assert(coreIndex < replicationFactor);
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
// Adjust the input based on the kernel
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
// Check if we are within the input image
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
out_y++;
continue;
}
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
auto mvm = curCores[coreIndex]->addMVM(
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
out_y++;
}
out_x++;
}
// Computations for these crossbars are done, check if the cores
// crossbars are fully used. If full, swap with new core
for (size_t i = 0; i < replicationFactor; i++) {
if (curCores[i]->isXbarsFull()) {
cores.emplace_back(std::move(curCores[i]));
curCores[i] = make_shared<Core>(coreId++, rewriter);
}
}
}
}
}
}
for (auto& curCore : curCores)
if (curCore->isCoreEmpty() == false)
cores.emplace_back(std::move(curCore));
curCores.clear();
// Now, do the reduction of each output pixel tile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
for (size_t out_x = 0; out_x < output_w; out_x++) {
for (size_t out_y = 0; out_y < output_h; out_y++) {
// First, check if some producers are within the same core. If this is
// true, `Core::addMVM` have already done the reduction within-core.
// This means that we only need to consider the last producer for that
// core.
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
for (auto producer : producers[outTile][out_x][out_y])
withinCoreReducedProducers[producer.core->coreId] = producer;
// Now, we need to apply inter-core reduction
// Base case with one producer
if (withinCoreReducedProducers.size() == 1) {
// TODO: Add the bias and apply mapping (if present)
auto singleProducer = withinCoreReducedProducers.begin()->second;
// Use last producer as the final result
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
continue;
}
// TODO: This is a linear reduction, not a tree reduction. We can do
// better: a tree reduction would make more computations happen in
// parallel.
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
auto it = withinCoreReducedProducers.begin();
it++;
while (it != withinCoreReducedProducers.end()) {
Producer_t curProducer = it->second;
shared_ptr<Core> core1;
shared_ptr<Core> core2;
Value core1Value;
Value core2Value;
auto lastProducerCoreId = lastProducer.core->coreId;
auto curProducerCoreId = curProducer.core->coreId;
assert(lastProducerCoreId != curProducerCoreId
&& "We should have already applied within-core reduction, how "
"could we have same cores here?");
// Sort the cores by coreId
if (curProducerCoreId < lastProducerCoreId) {
core1 = curProducer.core;
core1Value = curProducer.value;
core2 = lastProducer.core;
core2Value = lastProducer.value;
}
else {
core1 = lastProducer.core;
core1Value = lastProducer.value;
core2 = curProducer.core;
core2Value = curProducer.value;
}
auto newCoreRes = core1->makeResultRemappable(core1Value);
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
rewriter.setInsertionPointAfterValue(core2Value);
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
lastProducer = {vaddRes, core2};
it++;
}
// TODO: Add the bias and apply mapping (if present)
// Use last producer as the final result
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
}
}
}
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
rewriter.setInsertionPointAfter(conv);
spatial::SpatWeightedCompute lastWComputeOp;
for (auto& core : cores) {
lastWComputeOp = core->createWComputeOp(loc);
core->remapResults();
rewriter.setInsertionPointAfter(lastWComputeOp);
}
for (auto& core : cores)
core->addRemappedOperands();
// Set the insertion point after the last WComputeOp.
rewriter.setInsertionPointAfter(lastWComputeOp);
SmallVector<Value> tilesToConcat;
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
for (size_t outX = 0; outX < output_h; outX++)
for (size_t outY = 0; outY < output_w; outY++)
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
// Value outputImage =
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
// If no mapping (activation) was applied, just replace ConvOp
// if (mapOperation == MapOperations::None) {
// rewriter.replaceOp(conv, outputImage);
// } else {
// // If mapping was applied, erase ConvOp and replace the mapping op
// rewriter.eraseOp(conv);
// rewriter.replaceOp(firstUserOp, outputImage);
// }
return success();
}
};
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ConvToManyGemms>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,400 +0,0 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstddef>
#include <unistd.h>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A pattern to tile the convolution operation into a series of compute
* units, each one of which applies filters to a subset of the input
* tensor. Results are also reduced and concatenated to form the final
* output tensor.
*/
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ExperimentalONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
ShapedType outputType = cast<ShapedType>(conv.getY().getType());
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
// TODO: Address bigger batches.
assert(GET_IMAGE_N(inputType) == 1
&& "Batch size must be 1"
"for convolution.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
// TODO: Address bias addition.
ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(filterX),
/* 3 */ rewriter.getIndexAttr(filterY)};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(0),
/* 3 */ rewriter.getIndexAttr(0)};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
conv.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands)
block->addArgument(operand.getType(), conv->getLoc());
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
auto offsets = extractSliceOp.getStaticOffsets();
xKerPos.push_back(offsets[2]);
yKerPos.push_back(offsets[3]);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(
conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
// If the conv's user is a ReLU...
if (conv->hasOneUse()) {
Operation* user = *conv->getUsers().begin();
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
// ...then we can just replace the ReLU with the concatenation.
rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
// And erase the convolution.
rewriter.eraseOp(conv);
return success();
}
}
// Return the final output.
rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
return success();
}
};
/**
* @brief Populate the tiling pattern for a convolution operation.
*
* @param patterns The pattern set to populate.
* @param ctx The MLIR context.
*/
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ExperimentalONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,365 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdlib>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
ExperimentalGemmConversionPattern(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
// TODO: Address bigger batches.
assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
// TODO: Address bias addition.
assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
size_t kernelWidth = 1;
size_t kernelHeight = 1;
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ /* rewriter.getIndexAttr(1), */
/* 3 */ /* rewriter.getIndexAttr(1) */};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(filterX), */
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(0), */
/* 3 */ /* rewriter.getIndexAttr(0) */};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands)
block->addArgument(operand.getType(), gemmOp->getLoc());
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
xKerPos.push_back(0);
yKerPos.push_back(0);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(
gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
// Return the final output.
rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
return success();
}
};
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ExperimentalGemmConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -10,7 +10,6 @@
#include <cassert> #include <cassert>
#include "Gemm.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
@@ -20,6 +19,38 @@
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
private:
static Value resolveONNXExpOpFromUseChain(Value startValue);
static LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel,
ConversionPatternRewriter& rewriter,
SpatialReducer& reducer,
ONNXGemmOp& gemmOp,
Location& loc);
};
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor, ONNXGemmOpAdaptor gemmOpAdaptor,

View File

@@ -1,54 +0,0 @@
#pragma once
#include "Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
constexpr mlir::StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct GemmToManyGemv : mlir::OpConversionPattern<mlir::ONNXGemmOp> {
GemmToManyGemv(mlir::MLIRContext* ctx)
: OpConversionPattern(ctx, 2) {}
mlir::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp,
mlir::ONNXGemmOpAdaptor gemmOpAdaptor,
mlir::ConversionPatternRewriter& rewriter) const override;
};
struct GemvToSpatialCompute : mlir::OpConversionPattern<mlir::ONNXGemmOp> {
GemvToSpatialCompute(mlir::MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
llvm::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp,
mlir::ONNXGemmOpAdaptor gemmOpAdaptor,
mlir::ConversionPatternRewriter& rewriter) const override;
private:
/**
* Resolves the ONNXExpOp from the use chain of the given start value.
*
* This function traverses the use chain of the start value until it finds an
* ONNXExpOp. It returns the value of the ONNXExpOp.
*
* @param startValue The starting value of the use chain.
* @return The value of the ONNXExpOp found in the use chain.
*/
static mlir::Value resolveONNXExpOpFromUseChain(mlir::Value startValue);
// Softmax is a special case, as it requires another reduction after the
// first one. In the cores, `applyReducePattern` already applied
// f(x) = exp(x) to each tile. This mean that now we just need to
// reduce-sum these tiles, and then divide each tile by the reduced sum,
// which is propagated back to the cores via a broadcast channel.
static llvm::LogicalResult softmaxReductionApplication(llvm::SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel,
ConversionPatternRewriter& rewriter,
SpatialReducer& reducer,
ONNXGemmOp& gemmOp,
Location& loc);
};
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -1,300 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename PoolOp>
bool hasPostProcessExperimentalPoolingWindow() {
return false;
}
template <>
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename ReductionOp>
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
if (inputTiles.size() == 1)
return inputTiles[0];
if (inputTiles.size() == 2) {
return rewriter.create<spatial::SpatVMaxOp>(
inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
}
SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
// Assert that the input is a tensor.ConcatOp.
auto concat = X.getDefiningOp<tensor::ConcatOp>();
if (!concat)
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
// Create a [channel_tile][x][y] array to store the input tiles.
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
// For each argument of the tensor.ConcatOp, resolve the input tiles.
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
/* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// Get the concat's operand that we want to slice.
Value concatInput = concat.getOperand(it);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
inputTiles[it][x][y] = slicedTile;
}
}
}
// Prepare the shape of the compute's output.
ldiv_t itc = tileCount;
SmallVector<Type> outputTileTypes;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
/* 2 */ 1,
/* 3 */ 1};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
}
}
}
// Create a plain value list of the input tiles.
SmallVector<Value> inputTilesList;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x)
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
inputTilesList.push_back(inputTiles[it][y][x]);
}
// Create a single compute to calculate the output.
auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&computeOp.getRegion());
// Fill the block arguments and keep a reference to them.
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
}
}
}
// Begin writing in the block.
rewriter.setInsertionPointToStart(block);
// Go through all pooling blocks.
SmallVector<Value> outputTiles;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
size_t start_x = x * stride_x;
size_t start_y = y * stride_y;
size_t end_x = std::min(start_x + krn_w, input_w);
size_t end_y = std::min(start_y + krn_h, input_h);
SmallVector<Value> inputTilesToReduce;
for (size_t ky = start_y; ky < end_y; ++ky)
for (size_t kx = start_x; kx < end_x; ++kx)
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
// If the reduce op is add, we need to divide the result by the
// number of elements in the pooling window.
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
// Add a spat.const before the computeOp.
rewriter.setInsertionPoint(computeOp);
auto divisorValue =
rewriter.create<spatial::SpatConstantOp>(loc,
RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
reduceResult =
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
}
outputTiles.push_back(reduceResult);
}
}
}
// Create a YieldOp to return the output tiles.
rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
// Set the rewrite cursor right after the computeOp.
rewriter.setInsertionPointAfter(computeOp);
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
}
}
}
// We'll now create spat.img.concat ops to concatenate the output tiles.
SmallVector<Value> outputTilesList;
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<Value> imgConcatTiles;
for (size_t y = 0; y < output_h; ++y)
for (size_t x = 0; x < output_w; ++x)
imgConcatTiles.push_back(computeOutput[it][y][x]);
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ (long) tilingSize,
/* 2 */ (long) output_w,
/* 3 */ (long) output_h};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
}
// Create a new tensor.ConcatOp to concatenate the output tiles.
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.replaceOp(poolOp, outputTensor);
return success();
}
};
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
ctx);
}
} // namespace onnx_mlir

View File

@@ -26,8 +26,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce, Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce, std::function<Value(const Value&, const Value&)> reduce,
@@ -225,12 +223,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Location loc = poolOp.getLoc(); Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape); size_t input_h = getImageHeight(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape); size_t input_w = getImageWidth(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape); size_t output_h = getImageHeight(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape); size_t output_w = getImageWidth(yShape);
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize; size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
// 1: Tile the input tensor // 1: Tile the input tensor
// Input tiles need to be indexed by: // Input tiles need to be indexed by:

View File

@@ -13,9 +13,7 @@ def onnxToArithConstantOp : Pat<
(Arith_ConstantOp $value) (Arith_ConstantOp $value)
>; >;
//===----------------------------------------------------------------------===//
// ONNXMatMulOp to ONNXGemmOp patterns // ONNXMatMulOp to ONNXGemmOp patterns
//===----------------------------------------------------------------------===//
def matMulAddToGemmPattern : Pat< def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
@@ -39,9 +37,7 @@ def matMulToGemmPattern : Pat<
) )
>; >;
//===----------------------------------------------------------------------===//
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
//===----------------------------------------------------------------------===//
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single // This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias. // ONNXConvOp with a bias.
@@ -55,9 +51,7 @@ def convAddToConvWithBiasPatternRight : Pat<
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>; >;
//===----------------------------------------------------------------------===//
// Operation to ignore (i.e. remove) // Operation to ignore (i.e. remove)
//===----------------------------------------------------------------------===//
def replaceWithOperationOfValue : NativeCodeCall<"$0">; def replaceWithOperationOfValue : NativeCodeCall<"$0">;

View File

@@ -180,10 +180,10 @@ void tileImageTensorByChannel(Value imageTensor,
ConversionPatternRewriter& rewriter) { ConversionPatternRewriter& rewriter) {
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType()); ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
size_t input_h = GET_IMAGE_HEIGHT(imageShape); size_t input_h = getImageHeight(imageShape);
size_t input_w = GET_IMAGE_WIDTH(imageShape); size_t input_w = getImageWidth(imageShape);
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize); size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize; size_t tileRest = getImageChannel(imageShape) % tileSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));

View File

@@ -9,24 +9,55 @@
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#define DEFINE_MAP_OP(opname) opname, #define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
const StringRef REPLICATION_ATTR_NAME = "replication_factor"; template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
inline constexpr mlir::StringRef REPLICATION_ATTR_NAME = "replication_factor";
using HSliceId = size_t; using HSliceId = size_t;
using CoreId = size_t; using CoreId = size_t;
@@ -58,51 +89,64 @@ constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
} }
template <class T> template <class T>
bool isVectorShape(const ArrayRef<T> shape) { bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
} }
template <class T> template <class T>
bool isMatrixShape(const ArrayRef<T> shape) { bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2; return shape.size() == 2;
} }
template <class T> template <class T>
bool isHVectorShape(const ArrayRef<T> shape) { bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1; return shape.size() == 2 && shape[0] == 1;
} }
template <class T> template <class T>
bool isVVectorShape(const ArrayRef<T> shape) { bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1; return shape.size() == 2 && shape[1] == 1;
} }
template <class T> template <class T>
T getVectorLength(const ArrayRef<T> shape) { T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape)); assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1]; return shape[0] != 1 ? shape[0] : shape[1];
} }
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); } inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
SmallVector<Value> sliceTensor( llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
SmallVector<Value> llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
DenseMap<CoreId, SmallVector<Value>> llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc); const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix( llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc); tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
tensor::SplatOp mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc); int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter); mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input); mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
/** /**
* Unpacks an optional pair vector into two size_t values. * Unpacks an optional pair vector into two size_t values.
@@ -126,7 +170,8 @@ void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t
* *
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid * @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
*/ */
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y); std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
/** /**
* Tiles the image tensor by channel. * Tiles the image tensor by channel.
@@ -140,10 +185,10 @@ std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> val
* @param tileSize The size of each tile. * @param tileSize The size of each tile.
* @param rewriter The ConversionPatternRewriter used for creating operations. * @param rewriter The ConversionPatternRewriter used for creating operations.
*/ */
void tileImageTensorByChannel(Value imageTensor, void tileImageTensorByChannel(mlir::Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles, llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
size_t tileSize, size_t tileSize,
ConversionPatternRewriter& rewriter); mlir::ConversionPatternRewriter& rewriter);
/** /**
* Creates an ImgConcatOp based on the given tiles. * Creates an ImgConcatOp based on the given tiles.
@@ -159,10 +204,10 @@ void tileImageTensorByChannel(Value imageTensor,
* *
* @return The created ImgConcatOp. * @return The created ImgConcatOp.
*/ */
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles, mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
Location& loc, mlir::Location& loc,
Type outputType); mlir::Type outputType);
/** /**
* @brief Verifies if the given input coordinates and padding values are within * @brief Verifies if the given input coordinates and padding values are within
@@ -177,7 +222,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
* @return LogicalResult Returns success if the coordinates and padding are * @return LogicalResult Returns success if the coordinates and padding are
* within bounds, failure otherwise. * within bounds, failure otherwise.
*/ */
LogicalResult mlir::LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y); verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
/** /**
@@ -207,13 +252,14 @@ verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY,
* @return std::optional<llvm::Twine> An error message if the input tensor could * @return std::optional<llvm::Twine> An error message if the input tensor could
* not be resolved into tiles. * not be resolved into tiles.
*/ */
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor, std::optional<llvm::Twine>
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles, resolveImgInputTiles(mlir::Value wholeInputTensor,
size_t channelTileCount, llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
size_t channelTileRest, size_t channelTileCount,
size_t input_w, size_t channelTileRest,
size_t input_h, size_t input_w,
mlir::ConversionPatternRewriter& rewriter); size_t input_h,
mlir::ConversionPatternRewriter& rewriter);
/** /**
* Computes the boundaries of an image kernel application. * Computes the boundaries of an image kernel application.
@@ -258,6 +304,6 @@ void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcom
* @return The index of the result of the operation that produces the specified * @return The index of the result of the operation that produces the specified
* value. * value.
*/ */
int getResultIndex(Operation* op, Value v); int getResultIndex(mlir::Operation* op, mlir::Value v);
}; // namespace onnx_mlir }; // namespace onnx_mlir

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@@ -10,19 +11,39 @@
#include "Common/PIMCommon.hpp" #include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "Math/Conv.hpp"
#include "ONNXToSpatialPass.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -40,15 +61,19 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp); IRRewriter rewriter(moduleOp);
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin(); auto entryFunc = getPimEntryFunc(moduleOp);
if (annotateReplication(funcOp, rewriter).failed()) { if (failed(entryFunc)) {
signalPassFailure();
return;
}
if (annotateReplication(*entryFunc, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n"; llvm::dbgs() << "Failed during annotation for replication analysis\n";
signalPassFailure(); signalPassFailure();
return; return;
} }
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>(); target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addIllegalOp<ONNXMatMulOp>(); target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXGemmOp>(); target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>(); target.addIllegalOp<ONNXConvOp>();
@@ -62,16 +87,9 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx); patterns.add<removeLRNPattern>(ctx);
if (useExperimentalConvImpl) { populateConvOpPatterns(patterns, ctx);
populateExperimentalTilingConvOpPattern(patterns, ctx); populatePoolingTilingPattern(patterns, ctx);
populateExperimentalPoolingTilingPattern(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx);
populateGemmToConvConversionPattern(patterns, ctx);
}
else {
populateTilingConvOpPattern(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
}
populateONNXConcatToTensorConcatPattern(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx);
@@ -84,8 +102,8 @@ void ONNXToSpatialPass::runOnOperation() {
// Count the number of compute ops and check they do not exceed the core count // Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations()) for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<SpatWeightedCompute>(op)) if (isa<spatial::SpatWeightedCompute>(op))
computeOpsCount++; computeOpsCount++;
if (computeOpsCount > coresCount) { if (computeOpsCount > coresCount) {
@@ -102,22 +120,21 @@ void ONNXToSpatialPass::runOnOperation() {
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns)))) if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp); annotateWeightsConstants(*entryFunc);
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "spatial"); dumpModule(moduleOp, "spatial");
} }
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](arith::ConstantOp constantOp) { funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight = bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); }); llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
if (isAlwaysWeight) if (isAlwaysWeight)
constantOp->setAttr("weightAlways", UnitAttr::get(ctx)); markWeightAlways(constantOp);
}); });
} }
} // namespace spatial std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
using namespace mlir;
extern bool haveSameStaticShape(Value lhs, Value rhs);
namespace spatial {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -1,27 +1,20 @@
#pragma once #pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir { namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
// Experimental patterns.
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -10,7 +10,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy> template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> { struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx) RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {} : OpRewritePattern<OpTy>(ctx) {}

View File

@@ -49,11 +49,11 @@ LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& r
ShapedType xShape = mlir::cast<ShapedType>(X.getType()); ShapedType xShape = mlir::cast<ShapedType>(X.getType());
ShapedType wShape = mlir::cast<ShapedType>(W.getType()); ShapedType wShape = mlir::cast<ShapedType>(W.getType());
size_t input_w = GET_IMAGE_WIDTH(xShape); size_t input_w = getImageWidth(xShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape); size_t krn_h = getKernelHeight(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape); size_t krn_w = getKernelWidth(wShape);
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); size_t inputTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue()); size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount; auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;

View File

@@ -15,21 +15,21 @@
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced; llvm::SmallPtrSet<mlir::Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum, ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun, std::function<mlir::Value(const mlir::Value&)> processFun,
ConversionPatternRewriter& rewriter) { mlir::ConversionPatternRewriter& rewriter) {
assert(processFun); assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum); auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum); auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); spatial::SpatYieldOp yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum); mlir::Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result); rewriter.setInsertionPointAfterValue(result);
Value processedResult = processFun(result); mlir::Value processedResult = processFun(result);
if (processedResult == result) { if (processedResult == result) {
// Sometimes we want processedResult to return the same value but do // Sometimes we want processedResult to return the same value but do
// something else with it (e.g. in softmax we want to broadcast the value // something else with it (e.g. in softmax we want to broadcast the value
@@ -42,10 +42,11 @@ ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum
return yieldOp.getNumOperands() - 1; return yieldOp.getNumOperands() - 1;
} }
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum, OpAndResNum
std::function<Value(const Value&, const Value&)> reduce, SpatialReducer::applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&)> preprocess, std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
std::function<Value(const Value&)> postprocess) { std::function<mlir::Value(const mlir::Value&)> preprocess,
std::function<mlir::Value(const mlir::Value&)> postprocess) {
if (preprocess) if (preprocess)
for (auto& computeOpAndResNum : computeOpsAndResNum) for (auto& computeOpAndResNum : computeOpsAndResNum)
@@ -55,18 +56,18 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// computeOp. In this case, we need to apply the reduction within-computef // computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction // Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute; std::unordered_map<mlir::Operation*, mlir::Value> lastValueForCompute;
for (auto& computeOpAndResNum : computeOpsAndResNum) { for (auto& computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum); auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum)); mlir::Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation()); auto it = lastValueForCompute.find(computeOp.getOperation());
if (it != lastValueForCompute.end()) { if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction // If we have already seen this computeOp, apply the reduction
// within-compute // within-compute
Value lastWithinComputeValue = it->second; mlir::Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp()); assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
@@ -85,12 +86,12 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
computeOpsAndResNum.clear(); computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size()); computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute) { for (auto& entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first); auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second; auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that // We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it // case no need to add it
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false; bool yieldOpUseFound = false;
for (auto& use : valueWithinCompute.getUses()) { for (auto& use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) { if (use.getOwner() == yieldOp.getOperation()) {
@@ -110,7 +111,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
computeOpsAndResNum.push_back({computeOp, resultNum}); computeOpsAndResNum.push_back({computeOp, resultNum});
} }
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc(); mlir::Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
// Recursive algorithm to reduce the inputs to a single one: // Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating // - Take two inputs at a time, and reduce them into a single one, updating
@@ -118,7 +119,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// - Repeat until there is only one input left. // - Repeat until there is only one input left.
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum); llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
while (computeOpsRef.size() > 1) { while (computeOpsRef.size() > 1) {
SmallVector<ComputeAndResNum> nextComputeOps; llvm::SmallVector<ComputeAndResNum> nextComputeOps;
nextComputeOps.reserve(computeOpsRef.size() / 2); nextComputeOps.reserve(computeOpsRef.size() / 2);
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) { for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
auto [firstCompute, firstResultNum] = computeOpsRef[i]; auto [firstCompute, firstResultNum] = computeOpsRef[i];
@@ -135,23 +136,23 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// the number of results) // the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates` // See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator()); auto yieldOpFirstCompute = mlir::cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp // Add a new operand to the block of the second computeOp
Block& secondBlock = secondCompute.getBody().front(); mlir::Block& secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc); mlir::Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum = auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0]; secondCompute->getAttrOfType<mlir::DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1; auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp // Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator()); spatial::SpatYieldOp secondYield = mlir::cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum); mlir::Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation // Apply reduction operation
rewriter.setInsertionPoint(secondYield); rewriter.setInsertionPoint(secondYield);
Value reduced = reduce(formerRes2, formerRes1); mlir::Value reduced = reduce(formerRes2, formerRes1);
// Unfortunately, it is not possible to update the result in place, // Unfortunately, it is not possible to update the result in place,
// because we may have already referenced it by <computeOp, resultNum> // because we may have already referenced it by <computeOp, resultNum>
@@ -219,7 +220,7 @@ void SpatialReducer::finalizeReduceUpdates() {
// `opToReplacedCompute` // `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp]; auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp) if (!toComputeOp)
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp); toComputeOp = mlir::cast<spatial::SpatWeightedCompute>(toOp);
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!"); assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
@@ -234,31 +235,31 @@ void SpatialReducer::finalizeReduceUpdates() {
} }
} }
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) { mlir::Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates."); assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
Operation* opToCast; mlir::Operation* opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first); auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end()) if (it != opToReplacedCompute.end())
opToCast = it->second; opToCast = it->second;
else else
opToCast = opAndResNum.first; opToCast = opAndResNum.first;
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast); auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second); return computeOp.getResult(opAndResNum.second);
} }
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) { void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) { if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again // If we have already replaced the fromOp, we do not need to do it again
return; return;
} }
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp); auto oldComputeOp = mlir::cast<spatial::SpatWeightedCompute>(computeOp);
auto oldComputeOpNum = oldComputeOp->getNumOperands(); auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator()); auto yieldOp = mlir::cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) { if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map // No result was added, just add itself to the map
@@ -283,8 +284,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
// Since we replaced the old ComputeOp with a new one, we need to replace // Since we replaced the old ComputeOp with a new one, we need to replace
// all its results' uses // all its results' uses
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) { for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
Value oldResult = oldComputeOp.getResult(i); mlir::Value oldResult = oldComputeOp.getResult(i);
Value newResult = newComputeOp.getResult(i); mlir::Value newResult = newComputeOp.getResult(i);
// Replace the uses, except the uses of the compute ops which got deleted // Replace the uses, except the uses of the compute ops which got deleted
// previously // previously
@@ -298,9 +299,10 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
rewriter.eraseOp(oldComputeOp); rewriter.eraseOp(oldComputeOp);
} }
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles, mlir::Value
Location& loc, SpatialReducer::createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Type outputType) { mlir::Location& loc,
mlir::Type outputType) {
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates."); assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
@@ -309,8 +311,8 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
auto width = outputTiles[0].size(); auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size(); auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles( llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>> remappedOutputTiles(
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height))); tilesCount, llvm::SmallVector<llvm::SmallVector<mlir::Value>>(width, llvm::SmallVector<mlir::Value>(height)));
for (size_t t = 0; t < tilesCount; t++) for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++) for (size_t x = 0; x < width; x++)
@@ -320,16 +322,16 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType); return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
} }
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps, OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
Value biasTile, mlir::Value biasTile,
MapOperations mapOp) { MapOperations mapOp) {
std::function<Value(const Value&)> postprocessing = nullptr; std::function<mlir::Value(const mlir::Value&)> postprocessing = nullptr;
if (mapOp != MapOperations::None) { if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) { postprocessing = [&](const mlir::Value a) {
Value mapOperand = a; mlir::Value mapOperand = a;
if (biasTile) if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile); mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
return createMapOperation(rewriter, mapOp, mapOperand); return createMapOperation(rewriter, mapOp, mapOperand);
@@ -338,7 +340,7 @@ OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>&
return this->applyReducePattern( return this->applyReducePattern(
computeOps, computeOps,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); }, [&](mlir::Value a, mlir::Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
/* preprocess = */ nullptr, /* preprocess = */ nullptr,
postprocessing); postprocessing);
} }

View File

@@ -3,6 +3,10 @@
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include <functional>
#include <unordered_map>
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -13,28 +17,28 @@ using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>; using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange { struct SpatialReducerChange {
Operation* fromOp; mlir::Operation* fromOp;
unsigned int fromOpResNum; unsigned int fromOpResNum;
Operation* toOp; mlir::Operation* toOp;
unsigned int toOpOperandNum; unsigned int toOpOperandNum;
}; };
using OpAndResNum = std::pair<Operation*, ResNum>; using OpAndResNum = std::pair<mlir::Operation*, ResNum>;
class SpatialReducer { class SpatialReducer {
public: public:
SpatialReducer(ConversionPatternRewriter& rewriter) SpatialReducer(mlir::ConversionPatternRewriter& rewriter)
: rewriter(rewriter) {} : rewriter(rewriter) {}
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum, OpAndResNum applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce, std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
std::function<Value(const Value&)> preprocess, std::function<mlir::Value(const mlir::Value&)> preprocess,
std::function<Value(const Value&)> postprocess); std::function<mlir::Value(const mlir::Value&)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps, OpAndResNum applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
Value biasTile, mlir::Value biasTile,
MapOperations mapOp); MapOperations mapOp);
void finalizeReduceUpdates(); void finalizeReduceUpdates();
@@ -44,17 +48,17 @@ public:
finalizeReduceUpdates(); finalizeReduceUpdates();
} }
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles, mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Location& loc, mlir::Location& loc,
Type outputType); mlir::Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum); mlir::Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
private: private:
[[nodiscard("computeOp result number gets updated")]] ResNum [[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum, applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun, std::function<mlir::Value(const mlir::Value&)> processFun,
ConversionPatternRewriter& rewriter); mlir::ConversionPatternRewriter& rewriter);
/** /**
* @brief Update the results of a ComputeOp. * @brief Update the results of a ComputeOp.
@@ -66,19 +70,19 @@ private:
* *
* @param computeOp The ComputeOp to update the results of. * @param computeOp The ComputeOp to update the results of.
*/ */
void updateResultsOfCompute(Operation* computeOp); void updateResultsOfCompute(mlir::Operation* computeOp);
ConversionPatternRewriter& rewriter; mlir::ConversionPatternRewriter& rewriter;
bool reducesFinalized = false; bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized // List of changes to be applied after the reduction is finalized
SmallVector<SpatialReducerChange, 4> reducerChanges; llvm::SmallVector<SpatialReducerChange, 4> reducerChanges;
// List of computeOps that need to be replaced with new results // List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate; llvm::SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute; std::unordered_map<mlir::Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced; static llvm::SmallPtrSet<mlir::Operation*, 16> oldComputeOpsReplaced;
}; };
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
namespace onnx_mlir { namespace onnx_mlir {
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights) WeightSubdivider::WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights)
: weights(std::move(weights)) {} : weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); } bool WeightSubdivider::isEmpty() const { return weights.empty(); }
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract."); assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin(); auto it = weights.begin();
SmallVector<Value>& values = it->second.begin()->second; llvm::SmallVector<mlir::Value>& values = it->second.begin()->second;
long inputTile = it->first; long inputTile = it->first;
long outputTile = it->second.begin()->first; long outputTile = it->second.begin()->first;
@@ -21,7 +21,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
size_t n = std::min(amount, values.size()); size_t n = std::min(amount, values.size());
crossbarsUsed += n; crossbarsUsed += n;
SmallVector<Value> result; llvm::SmallVector<mlir::Value> result;
result.assign(values.begin(), values.begin() + n); result.assign(values.begin(), values.begin() + n);
if (n < values.size()) { if (n < values.size()) {
@@ -36,9 +36,9 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
return {inputTile, outputTile, crossbarsUsed - n, result}; return {inputTile, outputTile, crossbarsUsed - n, result};
} }
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) { llvm::SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
crossbarsUsed = 0; crossbarsUsed = 0;
SmallVector<TaggedWeights> result; llvm::SmallVector<TaggedWeights> result;
size_t remaining = n; size_t remaining = n;
while (remaining > 0 && !weights.empty()) { while (remaining > 0 && !weights.empty()) {

View File

@@ -4,11 +4,9 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <map> #include <map>
using namespace mlir;
using namespace std;
namespace onnx_mlir { namespace onnx_mlir {
/** /**
@@ -19,7 +17,7 @@ struct TaggedWeights {
long inputTile; long inputTile;
long outputTile; long outputTile;
size_t startingCrossbarIndex; size_t startingCrossbarIndex;
SmallVector<Value> weights; llvm::SmallVector<mlir::Value> weights;
}; };
/** /**
@@ -33,16 +31,16 @@ struct TaggedWeights {
*/ */
class WeightSubdivider { class WeightSubdivider {
private: private:
map<long, map<long, SmallVector<Value>>> weights; std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights;
size_t crossbarsUsed = 0; size_t crossbarsUsed = 0;
TaggedWeights popGroup(size_t amount); TaggedWeights popGroup(size_t amount);
public: public:
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights); WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights);
bool isEmpty() const; bool isEmpty() const;
SmallVector<TaggedWeights> popGroups(size_t n); llvm::SmallVector<TaggedWeights> popGroups(size_t n);
}; };
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -10,6 +10,7 @@
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -199,12 +200,12 @@ private:
void SpatialToGraphvizPass::runOnOperation() { void SpatialToGraphvizPass::runOnOperation() {
ModuleOp module = getOperation(); ModuleOp module = getOperation();
// Get the first OP, must be a FuncOp auto entryFunc = getPimEntryFunc(module);
func::FuncOp func = *module.getOps<func::FuncOp>().begin(); if (failed(entryFunc)) {
if (!func) {
module->emitError("No FuncOp found in the begin of module");
signalPassFailure(); signalPassFailure();
return;
} }
func::FuncOp func = *entryFunc;
os << "digraph G {\n" os << "digraph G {\n"
<< "\tnode [style=filled,color=white];\n"; << "\tnode [style=filled,color=white];\n";

View File

@@ -3,7 +3,6 @@ mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen) add_public_tablegen_target(SpatialToPIMIncGen)
add_onnx_mlir_library(OMSpatialToPIM add_onnx_mlir_library(OMSpatialToPIM
SpatialToPIMPass.hpp
SpatialToPIMPass.cpp SpatialToPIMPass.cpp
SpatialToPIMCommon.cpp SpatialToPIMCommon.cpp

View File

@@ -3,10 +3,18 @@
#ifndef OP_BASE #ifndef OP_BASE
include "mlir/IR/PatternBase.td" include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def onnxToPimTransposeOp : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMMOp : Pat< def spatToPimVMMOp : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector, (PimVMMOp $weightIndex, $vector,

View File

@@ -8,6 +8,7 @@
#include "SpatialToPIMCommon.hpp" #include "SpatialToPIMCommon.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -90,8 +91,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera
auto resultShapedType = cast<ShapedType>(resultType); auto resultShapedType = cast<ShapedType>(resultType);
rewriter.setInsertionPoint(operation); rewriter.setInsertionPoint(operation);
return rewriter.create<tensor::EmptyOp>( return tensor::EmptyOp::create(
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,8 +4,6 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
/** /**
@@ -20,10 +18,10 @@ namespace onnx_mlir {
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor * \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp. * \return The actual offset of the ExtractSliceOp.
*/ */
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape); size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
template <class T> template <class T>
size_t rangeLength(const iterator_range<T> range) { size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end()); return std::distance(range.begin(), range.end());
} }
@@ -35,16 +33,16 @@ size_t rangeLength(const iterator_range<T> range) {
* @return The earliest user operation that uses the given value within the * @return The earliest user operation that uses the given value within the
* current block. * current block.
*/ */
Operation* getEarliestUserWithinBlock(Value value); mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation); mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation); mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape, static bool isMemoryContiguous(const mlir::ArrayRef<int64_t> srcShape,
const ArrayRef<int64_t> offsets, const mlir::ArrayRef<int64_t> offsets,
const ArrayRef<int64_t> sizes, const mlir::ArrayRef<int64_t> sizes,
const ArrayRef<int64_t> strides) { const mlir::ArrayRef<int64_t> strides) {
// Check that all strides are 1 // Check that all strides are 1
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; })) if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false; return false;
@@ -99,10 +97,13 @@ static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
return true; return true;
} }
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) { inline mlir::tensor::EmptyOp
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType()); createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
} }
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); } inline bool isAConcatOp(mlir::Operation* op) {
return isa<mlir::tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op);
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -17,12 +17,64 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "SpatialToPIMPass.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace pim; using namespace pim;
using namespace spat_to_pim;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace
void SpatialToPIMPass::runOnOperation() { void SpatialToPIMPass::runOnOperation() {
coreId = 1; coreId = 1;
@@ -35,14 +87,17 @@ void SpatialToPIMPass::runOnOperation() {
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
populateWithGenerated(patterns); populateWithGenerated(patterns);
if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin(); auto entryFunc = getPimEntryFunc(moduleOp);
if (!funcOp) if (failed(entryFunc)) {
llvm_unreachable("No FuncOp found in the begin of module"); signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
@@ -260,7 +315,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
for (auto returnValue : returnOp->getOperands()) { for (auto returnValue : returnOp->getOperands()) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp(); Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) { if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!returnValueDefiningOp->hasAttr("weightAlways")); assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(returnValue); outputTensors.push_back(returnValue);
} }
else { else {
@@ -487,3 +542,7 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData()); rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
} }
} }
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -1,60 +0,0 @@
#pragma once
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace spat_to_pim {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace spat_to_pim
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<spat_to_pim::SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -1,2 +1,2 @@
add_subdirectory(PIM) add_subdirectory(Pim)
add_subdirectory(Spatial) add_subdirectory(Spatial)

View File

@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerOpBufferizationInterfaces(DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "Dialect/PIM/PimOps.hpp"
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace pim {
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace pim
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -175,7 +175,33 @@ def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
}]; }];
} }
// Computation // Algebra
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{ let description = [{
@@ -197,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
return getOutBufMutable(); return getOutBufMutable();
} }
}]; }];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
} }
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> { def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {

View File

@@ -10,7 +10,7 @@
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -20,7 +20,7 @@ namespace pim {
void PimDialect::initialize() { void PimDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
>(); >();
} }
@@ -45,5 +45,5 @@ POPULATE_DEPENDENCIES(PimVExpOp)
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"

View File

@@ -12,7 +12,7 @@
#include <string> #include <string>
/// Include the auto-generated header files containing the declarations /// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"

View File

@@ -3,7 +3,6 @@ mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen) add_public_tablegen_target(PimBufferizationIncGen)
add_onnx_mlir_library(OMPimBufferization add_onnx_mlir_library(OMPimBufferization
PimBufferizationPass.hpp
PimBufferizationPass.cpp PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp OpBufferizationInterfaces.cpp

View File

@@ -1,4 +1,4 @@
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -2,12 +2,10 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref); mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp" #include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
using namespace bufferization; using namespace bufferization;
@@ -76,6 +76,32 @@ struct MemCopyDevToHostOpInterface
} }
}; };
struct TransposeOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op);
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
if (failed(dataOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
return success();
}
};
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> { struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -176,6 +202,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx); PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx); PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx); PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
namespace pim {
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
#ifndef OP_BASE #ifndef OP_BASE
include "mlir/IR/PatternBase.td" include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td" include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def memrefCopyToPimMemCopyOp : Pat< def memrefCopyToPimMemCopyOp : Pat<

View File

@@ -7,12 +7,37 @@
#include "Common/PIMCommon.hpp" #include "Common/PIMCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace pim; using namespace pim;
namespace onnx_mlir {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace
void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation(); auto moduleOp = getOperation();
@@ -68,15 +93,18 @@ void PimBufferizationPass::runOnOperation() {
} }
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = !getGlobalOp->getUsers().empty() bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }); && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
if (isAlwaysWeight) { if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx)); markWeightAlways(getGlobalOp);
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx)); markWeightAlways(globalMemrefOp);
} }
}); });
} }
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -25,7 +25,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
LogicalResult SpatImgConcatOp::verify() { LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType()); auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape); size_t img_w = getImageWidth(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape); size_t img_h = getImageHeight(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape); size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize; size_t channelTileRest = img_c % crossbarSize;
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
return emitError("Invalid input type, must be ShapedType"); return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1 // N == W == H == 1
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1) if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1"); return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape); size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct: // Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that // - CASE1: last tile of pixel, if there is some rest it must match that
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) { Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands(); auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType()); auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape); size_t img_w = getImageWidth(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape); size_t img_h = getImageHeight(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape); size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());

View File

@@ -18,7 +18,7 @@
#include <cstdint> #include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -4,14 +4,12 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry); void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,7 +1,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"

View File

@@ -0,0 +1,181 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides(rank, 1);
SmallVector<int64_t> transposedStrides(rank, 1);
for (int64_t dim = rank - 2; dim >= 0; --dim) {
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
}
SmallVector<int64_t> originalIndices(rank);
SmallVector<int64_t> transposedIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
for (int64_t dim = 0; dim < rank; ++dim)
transposedIndices[dim] = originalIndices[perms[dim]];
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal)
return failure();
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
if (!denseAttr)
return failure();
SmallVector<int64_t> perms;
perms.reserve(transposeOp.getPerms().size());
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
perms.push_back(attr.getInt());
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
if (failed(transposedAttr))
return failure();
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
if (!llvm::equal(transposedShape, resultType.getShape()))
return failure();
MemRefType globalType = resultType;
auto globalName = sourceGlobal.getName().str() + "__folded_transpose";
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = sourceGlobal.getName().str() + "__folded_transpose_" + std::to_string(++suffix);
auto visibility = rewriter.getStringAttr("private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
auto newGlobal = memref::GlobalOp::create(moduleBuilder,
transposeOp.getLoc(),
globalName,
visibility,
globalType,
*transposedAttr,
/*constant=*/true,
sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
bool isAlwaysWeight =
!transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
if (isAlwaysWeight) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
return success();
}
};
struct PimFoldHostConstantsPass : PassWrapper<PimFoldHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimFoldHostConstantsPass)
StringRef getArgument() const override { return "fold-pim-host-constants-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
owningPatterns.add<FoldConstantTransposePattern>(context);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config)))
signalPassFailure();
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
std::unique_ptr<Pass> createPimFoldHostConstantsPass() { return std::make_unique<PimFoldHostConstantsPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,173 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isAddressOnlyHostOp(Operation* op) {
return isa<memref::AllocOp,
memref::GetGlobalOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp,
spatial::SpatChannelNewOp>(op);
}
static bool isHostAddressableValue(Value value) {
while (true) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return true;
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return false;
}
}
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
StringRef getArgument() const override { return "verify-pim-host-pass"; }
StringRef getDescription() const override {
return "Verify that no runtime host-side code remains in bufferized PIM IR";
}
PimHostVerificationPass() {}
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreOp)))
hasFailure = true;
continue;
}
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
if (failed(verifyReturnOp(returnOp)))
hasFailure = true;
continue;
}
if (!isAddressOnlyHostOp(&op)) {
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
"fold it to constants or lower it into pim.core");
hasFailure = true;
continue;
}
if (failed(verifyAddressOnlyHostOp(&op)))
hasFailure = true;
}
}
if (hasFailure)
signalPassFailure();
}
private:
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as memref.get_global before JSON codegen";
hasFailure = true;
continue;
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
hasFailure = true;
continue;
}
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must come from a constant memref.global with an initial value";
hasFailure = true;
}
}
return success(!hasFailure);
}
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
bool hasFailure = false;
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
if (!isHostAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
hasFailure = true;
}
}
return success(!hasFailure);
}
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlySource(op, subviewOp.getSource());
if (auto castOp = dyn_cast<memref::CastOp>(op))
return verifyAddressOnlySource(op, castOp.getSource());
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
return verifyAddressOnlySource(op, collapseOp.getSrc());
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
return verifyAddressOnlySource(op, expandOp.getSrc());
return success();
}
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
if (isHostAddressableValue(source))
return success();
op->emitOpError("depends on a value that still requires host-side execution");
return failure();
}
};
} // namespace
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
} // namespace onnx_mlir

View File

@@ -3,23 +3,26 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
#include <string>
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
std::unique_ptr<Pass> createONNXToSpatialPass(); std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<Pass> createSpatialToGraphvizPass(); std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<Pass> createSpatialToPIMPass(); std::unique_ptr<mlir::Pass> createSpatialToPIMPass();
std::unique_ptr<Pass> createBufferizePimPass(); std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<Pass> createEmitPimJsonPass(); std::unique_ptr<mlir::Pass> createPimFoldHostConstantsPass();
std::unique_ptr<Pass> createMessagePass(std::string message); std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
std::unique_ptr<Pass> createCountInstructionPass(); std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
std::unique_ptr<mlir::Pass> createCountInstructionPass();
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -12,8 +13,8 @@
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -40,28 +41,27 @@ PimAccelerator::PimAccelerator()
acceleratorTargets.push_back(this); acceleratorTargets.push_back(this);
} }
PimAccelerator::~PimAccelerator() { delete instance; }
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; } uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
void PimAccelerator::addPasses(OwningOpRef<ModuleOp>& module, void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
PassManager& pm, mlir::PassManager& pm,
EmissionTargetType& emissionTarget, EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const { std::string outputNameNoExt) const {
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
addPassesPim(module, pm, emissionTarget, outputNameNoExt); addPassesPim(module, pm, emissionTarget, outputNameNoExt);
} }
void PimAccelerator::registerDialects(DialectRegistry& registry) const { void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
registry.insert<tensor::TensorDialect>(); registry.insert<mlir::tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>(); registry.insert<mlir::tosa::TosaDialect>();
registry.insert<bufferization::BufferizationDialect>(); registry.insert<mlir::bufferization::BufferizationDialect>();
registry.insert<pim::PimDialect>(); registry.insert<pim::PimDialect>();
registry.insert<spatial::SpatialDialect>(); registry.insert<spatial::SpatialDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry); mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry); mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerOpBufferizationInterfaces(registry); pim::registerOpBufferizationInterfaces(registry);
@@ -73,6 +73,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPIMPass); registerPass(createSpatialToPIMPass);
registerPass(createBufferizePimPass); registerPass(createBufferizePimPass);
registerPass(createPimFoldHostConstantsPass);
registerPass(createPimHostVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }
@@ -81,26 +83,26 @@ void PimAccelerator::configurePasses() const {
// TODO: This does nothing for now. // TODO: This does nothing for now.
} }
MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const { mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const {
// Do not convert tensor types to memref types. // Do not convert tensor types to memref types.
return nullptr; return nullptr;
} }
void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const { void PimAccelerator::conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const {
target.addLegalDialect<pim::PimDialect>(); target.addLegalDialect<pim::PimDialect>();
} }
void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns, void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
TypeConverter& typeConverter, mlir::TypeConverter& typeConverter,
MLIRContext* ctx) const { mlir::MLIRContext* ctx) const {
// TODO: Add patterns for conversion // TODO: Add patterns for conversion
} }
void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {} void PimAccelerator::conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const {}
void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns, void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
LLVMTypeConverter& typeConverter, mlir::LLVMTypeConverter& typeConverter,
MLIRContext* ctx) const { mlir::MLIRContext* ctx) const {
// We should not need this, since we offload it all to PIM. // We should not need this, since we offload it all to PIM.
} }

View File

@@ -18,8 +18,6 @@ public:
PimAccelerator(PimAccelerator&) = delete; PimAccelerator(PimAccelerator&) = delete;
void operator=(const PimAccelerator&) = delete; void operator=(const PimAccelerator&) = delete;
~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations /// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance. /// return the existing instance.
static PimAccelerator* getInstance(); static PimAccelerator* getInstance();

View File

@@ -102,46 +102,13 @@ def gen_c(inputs, outputs, entry, so_name):
if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}} if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}}
""")) """))
# Output printing + optional per-output CSV dump # Optional per-output CSV dump
out_blocks=[]
csv_write_blocks=[] csv_write_blocks=[]
for oi,name,et,shape in outputs: for oi,name,et,shape in outputs:
if et not in DTYPES: if et not in DTYPES:
raise ValueError(f"Unsupported dtype for output '{name}': {et}") raise ValueError(f"Unsupported dtype for output '{name}': {et}")
cty, pfmt, _ = DTYPES[et] cty, pfmt, _ = DTYPES[et]
safe = esc(name) safe = esc(name)
out_blocks.append(textwrap.dedent(f"""
// ---- Output {oi}: "{safe}" ----
{{
OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
int64_t rank = omTensorGetRank(t);
int64_t const *shape = omTensorGetShape(t);
long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
{cty} *p = ({cty}*)omTensorGetDataPtr(t);
printf("Output {oi} ('{safe}'): shape=[");
for (int64_t k=0;k<rank;k++) printf("%ld%s",(long)shape[k], (k+1<rank)?",":"");
printf("]\\n");
if (rank == 2) {{
int64_t R = shape[0], C = shape[1];
for (int64_t r=0; r<R; ++r) {{
for (int64_t c=0; c<C; ++c) {{
long long idx = r*C + c;
printf("{pfmt}%s", p[idx], (c+1<C)?", ":"");
}}
printf("\\n");
}}
}} else {{
// Flattened vector with indices
for (long long i=0;i<numel;i++) {{
printf("[%lld]={pfmt}%s", i, p[i], (i+1<numel)?", ":"\\n");
}}
}}
}}
"""))
# Per-output CSV writer into --save-csv-dir
csv_write_blocks.append(textwrap.dedent(f""" csv_write_blocks.append(textwrap.dedent(f"""
if (save_csv_dir) {{ if (save_csv_dir) {{
// Build "DIR/output{oi}_<sanitized name>.csv" // Build "DIR/output{oi}_<sanitized name>.csv"
@@ -227,9 +194,6 @@ int main(int argc, char **argv) {{
OMTensorList *out_list = {entry}(in_list); OMTensorList *out_list = {entry}(in_list);
if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}} if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}}
// ---- Print full outputs ----
{"".join(out_blocks)}
// ---- Optional per-output CSV dump ---- // ---- Optional per-output CSV dump ----
{"".join(csv_write_blocks)} {"".join(csv_write_blocks)}
@@ -240,7 +204,7 @@ int main(int argc, char **argv) {{
}} }}
""" """
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None): def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None, verbose=True):
ins, outs = onnx_io(network_onnx) ins, outs = onnx_io(network_onnx)
out_c = out or "runner.c" out_c = out or "runner.c"
so_abs = os.path.abspath(network_so) so_abs = os.path.abspath(network_so)
@@ -260,8 +224,9 @@ set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)})
target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so) target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so)
""" """
pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake) pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake)
print(f"[OK] Wrote {out_c}") if verbose:
print("[OK] Wrote CMakeLists.txt") print(f"[OK] Wrote {out_c}")
print("[OK] Wrote CMakeLists.txt")
if __name__=="__main__": if __name__=="__main__":
ap=argparse.ArgumentParser() ap=argparse.ArgumentParser()

View File

@@ -1,9 +1,10 @@
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from colorama import Fore, Style from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter
def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count): def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count, reporter=None):
# Define the arguments, with the possibility to set crossbar size and count # Define the arguments, with the possibility to set crossbar size and count
args = [ args = [
network_path, network_path,
@@ -14,16 +15,14 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, cro
f"--crossbar-count={crossbar_count}", f"--crossbar-count={crossbar_count}",
] ]
# Run the executable with the arguments
try: try:
result = subprocess.run( run_command_with_reporter(
[str(raptor_onnx_path)] + [str(arg) for arg in args], [str(raptor_onnx_path)] + [str(arg) for arg in args],
check=True, reporter=reporter,
capture_output=True,
text=True,
) )
print(result.stdout + Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL) if reporter is None:
except subprocess.CalledProcessError as e: print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
print(Fore.RED + "Error executing ONNX-MLIR:") except subprocess.CalledProcessError:
print(e.stderr + Style.RESET_ALL) if reporter is None:
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
raise raise

View File

@@ -0,0 +1,70 @@
import errno
import os
import pty
import selectors
import subprocess
def _read_chunk(fd, treat_eio_as_eof=False):
try:
return os.read(fd, 4096)
except OSError as exc:
if treat_eio_as_eof and exc.errno == errno.EIO:
return b""
raise
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
selector = selectors.DefaultSelector()
try:
selector.register(fd, selectors.EVENT_READ)
while selector.get_map():
for key, _ in selector.select():
data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
if not data:
selector.unregister(key.fileobj)
os.close(key.fileobj)
continue
reporter._clear()
os.write(1, data)
reporter._render()
finally:
selector.close()
return_code = process.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, process.args)
def run_command_with_reporter(cmd, cwd=None, reporter=None):
if reporter is None:
subprocess.run(cmd, cwd=cwd, check=True)
return
try:
master_fd, slave_fd = pty.openpty()
except OSError:
process = subprocess.Popen(
cmd,
cwd=cwd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
_stream_output(process.stdout.fileno(), process, reporter)
return
try:
process = subprocess.Popen(
cmd,
cwd=cwd,
stdout=slave_fd,
stderr=slave_fd,
)
finally:
os.close(slave_fd)
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True)

View File

@@ -1,10 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from validate_one import validate_network from validate_one import ProgressReporter, validate_network
def main(): def main():
@@ -34,32 +35,48 @@ def main():
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL) print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
sys.exit(1) sys.exit(1)
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL) print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
print(f"Operations root: {operations_dir}")
print("=" * 72)
results = {} # relative_path -> passed results = {} # relative_path -> passed
for onnx_path in onnx_files: reporter = ProgressReporter(len(onnx_files))
for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir) rel = onnx_path.relative_to(operations_dir)
header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}" try:
print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL) passed = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
reporter=reporter,
model_index=index,
model_total=len(onnx_files),
)
results[str(rel)] = passed
except (subprocess.CalledProcessError, Exception):
results[str(rel)] = False
passed = validate_network( reporter.finish()
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
)
results[str(rel)] = passed
# Summary # Summary
n_passed = sum(results.values()) n_passed = sum(1 for passed in results.values() if passed)
n_total = len(results) n_total = len(results)
print("\n" + Style.BRIGHT + "=" * 60) status_width = len("Result")
print(" Summary") path_width = max(len("Operation"), *(len(rel) for rel in results))
print("=" * 60 + Style.RESET_ALL) separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
print(separator)
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
print(separator)
for rel, passed in results.items(): for rel, passed in results.items():
status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL" plain_status = "PASS" if passed else "FAIL"
print(f" {rel}: {status}" + Style.RESET_ALL) status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." + Style.RESET_ALL) Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
print(f"| {rel.ljust(path_width)} | {status} |")
print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
sys.exit(0 if n_passed == n_total else 1) sys.exit(0 if n_passed == n_total else 1)

View File

@@ -2,16 +2,114 @@ import argparse
import json import json
import numpy as np import numpy as np
import subprocess import subprocess
import shutil
import sys
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
from raptor import compile_with_raptor from raptor import compile_with_raptor
from gen_network_runner import gen_network_runner from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir): STAGE_COUNT = 6
subprocess.run([raptor_path, network_onnx_path, "--EmitONNXIR"], check=True)
subprocess.run([raptor_path, network_onnx_path], check=True)
class ProgressReporter:
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
self.total_models = total_models
self.stages_per_model = stages_per_model
self.total_steps = max(1, total_models * stages_per_model)
self.completed_steps = 0
self.current_label = ""
self.enabled = True
self.columns = shutil.get_terminal_size((100, 20)).columns
self.suspended = False
def _clear(self):
if self.enabled:
sys.stdout.write("\033[2K\r")
def _render(self):
if not self.enabled or self.suspended:
return
bar_width = 24
filled = int(bar_width * self.completed_steps / self.total_steps)
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
if len(prefix_text) > self.columns:
prefix_text = f"{self.completed_steps}/{self.total_steps}"
label = f" {self.current_label}" if self.current_label else ""
available_label_width = max(0, self.columns - len(prefix_text))
label = label[:available_label_width]
if prefix_text.startswith("["):
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
else:
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
sys.stdout.flush()
def log(self, message="", color=None):
if self.enabled:
self._clear()
if color:
print(color + message + Style.RESET_ALL)
else:
print(message)
self._render()
def set_stage(self, model_index, model_total, model_name, stage_name):
self.current_label = f"[{model_index}/{model_total}] {model_name} · {stage_name}"
self._render()
def advance(self):
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
self._render()
def suspend(self):
self.suspended = True
self._clear()
def resume(self):
self.suspended = False
self._render()
def finish(self):
if self.enabled:
self.suspended = True
self._clear()
sys.stdout.flush()
def run_command(cmd, cwd=None, reporter=None):
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter)
def print_stage(reporter, model_index, model_total, model_name, title):
stage_colors = {
"Compile ONNX": Fore.BLUE,
"Build Runner": Fore.MAGENTA,
"Generate Inputs": Fore.YELLOW,
"Run Reference": Fore.GREEN,
"Compile PIM": Fore.CYAN,
"Run Simulator": Fore.MAGENTA,
"Compare Outputs": Fore.YELLOW,
}
color = stage_colors.get(title, Fore.WHITE)
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
reporter.set_stage(model_index, model_total, model_name, title)
def print_info(reporter, message):
reporter.log(f" {message}")
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
run_command([raptor_path, network_onnx_path, "--EmitONNXIR"], reporter=reporter)
run_command([raptor_path, network_onnx_path], reporter=reporter)
parent = network_onnx_path.parent parent = network_onnx_path.parent
stem = network_onnx_path.stem stem = network_onnx_path.stem
so_path = parent / f"{stem}.so" so_path = parent / f"{stem}.so"
@@ -25,9 +123,9 @@ def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir)
return moved_so, moved_mlir return moved_so, moved_mlir
def build_onnx_runner(source_dir, build_dir): def build_onnx_runner(source_dir, build_dir, reporter=None):
subprocess.run(["cmake", source_dir], cwd=build_dir, check=True) run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter)
subprocess.run(["cmake", "--build", ".", "-j"], cwd=build_dir, check=True) run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter)
return build_dir / "runner" return build_dir / "runner"
@@ -41,11 +139,12 @@ def build_dump_ranges(config_path, outputs_descriptor):
return ",".join(ranges) return ",".join(ranges)
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
subprocess.run( run_command(
["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir, check=True cwd=simulator_dir,
reporter=reporter,
) )
@@ -64,24 +163,41 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor):
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3): def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3):
all_passed = True all_passed = True
rows = []
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
csv_name = f"output{oi}_{name}.csv" csv_name = f"output{oi}_{name}.csv"
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape) runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64)))) max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
passed = max_diff <= threshold passed = max_diff <= threshold
status = Fore.GREEN + "[PASS]" if passed else Fore.RED + "[FAIL]" rows.append((name, f"{max_diff:.6e}", passed))
print(f" {name}: max diff = {max_diff:.6e} {status}" + Style.RESET_ALL)
if not passed: if not passed:
all_passed = False all_passed = False
name_width = max(len("Output"), *(len(name) for name, _, _ in rows))
diff_width = max(len("Max diff"), *(len(diff) for _, diff, _ in rows))
result_width = len("Result")
separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+"
print(separator)
print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
print(separator)
for name, diff_text, passed in rows:
status_text = ("PASS" if passed else "FAIL").ljust(result_width)
status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL
print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |")
print(separator)
return all_passed return all_passed
def validate_network(network_onnx_path, raptor_path, onnx_include_dir, def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3): simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3,
reporter=None, model_index=1, model_total=1):
network_onnx_path = Path(network_onnx_path).resolve() network_onnx_path = Path(network_onnx_path).resolve()
raptor_path = Path(raptor_path).resolve() raptor_path = Path(raptor_path).resolve()
onnx_include_dir = Path(onnx_include_dir).resolve() onnx_include_dir = Path(onnx_include_dir).resolve()
simulator_dir = Path(simulator_dir).resolve() simulator_dir = Path(simulator_dir).resolve()
owns_reporter = reporter is None
reporter = reporter or ProgressReporter(model_total)
workspace_dir = network_onnx_path.parent workspace_dir = network_onnx_path.parent
raptor_dir = workspace_dir / "raptor" raptor_dir = workspace_dir / "raptor"
@@ -90,40 +206,72 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
Path.mkdir(raptor_dir, exist_ok=True) Path.mkdir(raptor_dir, exist_ok=True)
Path.mkdir(runner_build_dir, parents=True, exist_ok=True) Path.mkdir(runner_build_dir, parents=True, exist_ok=True)
print(Style.BRIGHT + "\nCompiling the onnx network:" + Style.RESET_ALL) reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
network_so_path, network_mlir_path = compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir) f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
print(Style.BRIGHT + "\nGenerating and building the runner:" + Style.RESET_ALL) try:
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c") print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
runner_path = build_onnx_runner(runner_dir, runner_build_dir) network_so_path, network_mlir_path = compile_onnx_network(
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter)
print_info(reporter, f"MLIR saved to {network_mlir_path}")
print_info(reporter, f"Shared library saved to {network_so_path}")
reporter.advance()
print(Style.BRIGHT + "\nGenerating random inputs:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path) gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor) runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs") print_info(reporter, f"Runner built at {runner_path}")
reporter.advance()
print(Style.BRIGHT + "\nRunning inference with the runner:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Generate Inputs")
out_dir = workspace_dir / "outputs" inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path)
Path.mkdir(out_dir, exist_ok=True) inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor)
run_cmd = [runner_path, *flags] flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs")
run_cmd += ["--save-csv-dir", f"{out_dir}"] print_info(reporter, f"Saved {len(inputs_list)} input file(s) to {workspace_dir / 'inputs'}")
subprocess.run(run_cmd, cwd=runner_build_dir, check=True) reporter.advance()
print(Style.BRIGHT + "\nCompiling for PIM with Raptor:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Reference")
compile_with_raptor(network_mlir_path, raptor_path, crossbar_size, crossbar_count) out_dir = workspace_dir / "outputs"
Path.mkdir(out_dir, exist_ok=True)
run_cmd = [runner_path, *flags]
run_cmd += ["--save-csv-dir", f"{out_dir}"]
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter)
print_info(reporter, f"Reference outputs saved to {out_dir}")
reporter.advance()
print(Style.BRIGHT + "\nRunning PIM simulation:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
pim_dir = raptor_dir / "pim" compile_with_raptor(
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list) network_mlir_path, raptor_path, crossbar_size, crossbar_count, reporter=reporter)
simulation_dir = workspace_dir / "simulation" print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
Path.mkdir(simulation_dir, exist_ok=True) reporter.advance()
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
output_bin_path = simulation_dir / "out.bin"
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges)
print(Style.BRIGHT + "\nValidating the results:" + Style.RESET_ALL) print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Simulator")
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor) pim_dir = raptor_dir / "pim"
return validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold) write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list)
simulation_dir = workspace_dir / "simulation"
Path.mkdir(simulation_dir, exist_ok=True)
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
output_bin_path = simulation_dir / "out.bin"
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter)
print_info(reporter, f"Simulator output saved to {output_bin_path}")
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs")
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
reporter.suspend()
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
reporter.resume()
reporter.advance()
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
except Exception:
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
raise
finally:
reporter.log("=" * 72)
if owns_reporter:
reporter.finish()
if __name__ == '__main__': if __name__ == '__main__':