add constant folding and verification pass for pim host operations

better validation scripts output
big refactors
This commit is contained in:
NiccoloN
2026-03-20 12:08:12 +01:00
parent 4e50e056e3
commit 6e1de865bb
64 changed files with 1364 additions and 2265 deletions

View File

@@ -20,6 +20,8 @@ add_onnx_mlir_library(OMPIMAccel
Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp
Pass/PimFoldHostConstantsPass.cpp
Pass/PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
@@ -43,4 +45,5 @@ add_onnx_mlir_library(OMPIMAccel
OMSpatialToGraphviz
OMSpatialToPIM
OMPIMCommon
MLIRTensorInferTypeOpInterfaceImpl
)

View File

@@ -5,6 +5,7 @@
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
@@ -30,6 +31,60 @@ void dumpModule(ModuleOp moduleOp, const std::string& name) {
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
  // Resolve the function the PIM pipeline should treat as the program entry.
  // Resolution order: explicit ONNX entry point -> "main_graph" symbol ->
  // the unique non-external function. Emits a diagnostic and fails otherwise.
  if (!moduleOp)
    return failure();

  // Prefer an explicit ONNX entry point when the module declares one.
  auto entryPointRange = moduleOp.getOps<ONNXEntryPointOp>();
  SmallVector<ONNXEntryPointOp> entryPointOps(entryPointRange.begin(), entryPointRange.end());
  if (entryPointOps.size() > 1) {
    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ")
        << entryPointOps.size();
    return failure();
  }
  if (entryPointOps.size() == 1) {
    ONNXEntryPointOp entryPointOp = entryPointOps.front();
    auto funcRefAttr =
        entryPointOp->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
    if (!funcRefAttr) {
      entryPointOp.emitOpError("is missing the entry point function attribute");
      return failure();
    }
    auto funcName = funcRefAttr.getLeafReference().getValue();
    if (auto referencedFunc = moduleOp.lookupSymbol<func::FuncOp>(funcName))
      return referencedFunc;
    entryPointOp.emitOpError("references an unknown entry function ") << funcName;
    return failure();
  }

  // No entry point op: fall back to the conventional "main_graph" symbol.
  if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
    return mainGraphFunc;

  // Last resort: accept a module containing exactly one function with a body.
  func::FuncOp soleCandidate;
  bool ambiguous = false;
  for (func::FuncOp candidate : moduleOp.getOps<func::FuncOp>()) {
    if (candidate.isExternal())
      continue;
    if (soleCandidate) {
      ambiguous = true;
      break;
    }
    soleCandidate = candidate;
  }
  if (soleCandidate && !ambiguous)
    return soleCandidate;
  moduleOp.emitError("could not resolve a unique PIM entry function");
  return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME) != nullptr; }
/// Tags `op` with the unit "weightAlways" attribute; the attribute's mere
/// presence is the signal (queried via hasWeightAlways).
void markWeightAlways(Operation* op) {
  assert(op && "expected valid op");
  auto* ctx = op->getContext();
  op->setAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME, UnitAttr::get(ctx));
}
/// Resolves the memref.global referenced by a memref.get_global.
/// Tolerates null inputs and a missing symbol by returning a null op.
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
  if (moduleOp && getGlobalOp)
    return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
  return {};
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();

View File

@@ -1,6 +1,8 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
@@ -9,6 +11,7 @@
#include "src/Compiler/CompilerOptions.hpp"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
inline constexpr llvm::StringRef PIM_WEIGHT_ALWAYS_ATTR_NAME = "weightAlways";
namespace onnx_mlir {
@@ -18,6 +21,14 @@ void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);

View File

@@ -13,11 +13,12 @@
#include <cassert>
#include <cmath>
#include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
@@ -49,8 +50,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!getGlobalOp->hasAttr("weightAlways")) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
auto iter = globalConstants.find(globalMemrefOp);
if (iter == globalConstants.end())
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
@@ -81,7 +82,7 @@ MemEntry PimMemory::getMemEntry(mlir::Value value) const {
return iter->second;
}
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second;
}
@@ -112,10 +113,33 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
}
value = source;
}
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
}
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
}
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
}
else
break;
}
return memEntriesMap.at(value).address + offset;
auto iter = memEntriesMap.find(value);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
value.print(errs());
errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Missing mem entry");
}
return iter->second.address + offset;
}
json::Object PimCodeGen::createEmptyOffset() {
@@ -348,6 +372,55 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
}
}
/// Lower a PIM transpose to element-by-element local memory copies:
/// out[t] = in[s] with t[d] = s[perm[d]], where (ONNX/numpy convention)
/// dstShape[d] = srcShape[perm[d]].
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
  auto srcAddr = memory.getValueAddress(transposeOp.getData());
  auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
  auto srcType = cast<ShapedType>(transposeOp.getData().getType());
  auto srcShape = srcType.getShape();
  size_t rank = srcShape.size();
  // Element size in bytes; sub-byte element types are not supported here.
  assert(srcType.getElementTypeBitWidth() % 8 == 0 && "expected byte-sized elements");
  size_t elementSize = srcType.getElementTypeBitWidth() / 8;
  size_t totalElements = srcType.getNumElements();
  // Read the permutation attribute.
  SmallVector<int64_t> perm =
      map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
  assert(perm.size() == rank && "permutation rank mismatch");
  // Destination shape: dstShape[i] = srcShape[perm[i]]
  SmallVector<int64_t> dstShape(rank);
  for (size_t i = 0; i < rank; i++)
    dstShape[i] = srcShape[perm[i]];
  // Row-major strides for source and destination.
  // Cast rank before subtracting so rank < 2 does not wrap around in size_t.
  SmallVector<size_t> srcStrides(rank, 1);
  SmallVector<size_t> dstStrides(rank, 1);
  for (int64_t i = static_cast<int64_t>(rank) - 2; i >= 0; i--) {
    srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
    dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
  }
  // Emit element-by-element copy with transposed addressing
  for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
    // Decompose flat source index into multi-dimensional index
    SmallVector<size_t> srcIdx(rank);
    size_t remaining = srcFlat;
    for (size_t d = 0; d < rank; d++) {
      srcIdx[d] = remaining / srcStrides[d];
      remaining %= srcStrides[d];
    }
    // Flat destination index. Since dstShape[d] == srcShape[perm[d]], the
    // destination coordinate along axis d is srcIdx[perm[d]].
    // BUGFIX: the previous code indexed with the inverse permutation, which is
    // only correct for involutive permutations (e.g. [1, 0]) and otherwise
    // scatters elements to wrong — possibly out-of-range — offsets.
    size_t dstFlat = 0;
    for (size_t d = 0; d < rank; d++)
      dstFlat += srcIdx[perm[d]] * dstStrides[d];
    emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
  }
}
size_t getMatrixSize(ShapedType matrixShape) {
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
assert(false && "Unsupported matrix shape");
@@ -378,9 +451,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (getGlobalOp->hasAttr("weightAlways"))
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
auto initialValue = globalOp.getInitialValue();
@@ -416,7 +489,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op))
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
@@ -435,6 +508,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
@@ -475,7 +550,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
continue;
}
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr());
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
@@ -589,9 +664,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
}
}
auto funcOps = moduleOp.getOps<func::FuncOp>();
assert(!funcOps.empty() && "No function found in the module");
auto funcOp = *funcOps.begin();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc))
return CompilerFailure;
auto funcOp = *entryFunc;
PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp);

View File

@@ -5,7 +5,7 @@
#include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -49,7 +49,7 @@ public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {}
PimMemory getOrCreateDeviceMem(size_t id);
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value) const;
};
@@ -95,6 +95,7 @@ public:
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
};
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);

View File

@@ -25,7 +25,6 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> exportCrossbarWeights;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;

View File

@@ -2,7 +2,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
@@ -46,6 +46,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimFoldHostConstantsPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimHostVerificationPass());
pm.addPass(createMessagePass("Pim host verified"));
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted"));

View File

@@ -3,21 +3,15 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.hpp
Math/Gemm.cpp
Math/Conv.hpp
Math/Conv.cpp
Math/ExperimentalConv.cpp
Math/ExperimentalGemm.cpp
NN/Pooling.cpp
NN/ExperimentalPooling.cpp
NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp
Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp
ONNXToSpatialPass.hpp
ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp

View File

@@ -242,6 +242,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return success();
}
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
} // namespace onnx_mlir

View File

@@ -18,6 +18,6 @@ struct ConvToGemm : mlir::OpConversionPattern<mlir::ONNXConvOp> {
mlir::ConversionPatternRewriter& rewriter) const override;
};
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -1,583 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
// NOTE: It might be worth re-implementing this using explicit for loops;
// the number of crossbars needed would then be:
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
/**
 * @brief A momentary representation of a core, to be used within the tiling of
 * a convolution operation.
 *
 * A Core accumulates crossbar weights, MVM operations and partial results
 * while the tiling pattern runs, and is finally materialized as a single
 * spatial::SpatWeightedCompute op via createWComputeOp().
 */
class Core {
public:
  Core(const size_t coreId, ConversionPatternRewriter& rewriter)
      : coreId(coreId), rewriter(rewriter) {}

  /**
   * @brief Add a MVM operation to the core.
   *
   * @param inputTile The input tile to the MVM operation.
   * @param xbarIndex The index of the crossbar weight to use.
   * @param outputTileId The id of the output tile.
   * @param mvmOutType The result's shape.
   * @return Value The result of the MVM operation.
   */
  Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
    // Use the inputTile as the reference location for the MVM operation.
    Location loc = inputTile.getLoc();
    // Move the insertion point to the end of the block.
    rewriter.setInsertionPointToEnd(block.get());
    // Add the inputTile to the block arguments, and to the operands
    // (only once per distinct input tile — operandMap deduplicates).
    Value operand = operandMap.lookupOrNull(inputTile);
    if (not operand) {
      operand = block->addArgument(inputTile.getType(), loc);
      operands.push_back(inputTile);
      operandMap.map(inputTile, operand);
    }
    // TODO: Compute the output type using the matrix, and check if `mvmOutType`
    // is correct.
    // Construct the MVM operation
    Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
    // Since we are within the same core and no computation can happen in
    // parallel, we can just apply a linear reduction in case we have multiple
    // MVM operations for the same outputTile.
    auto lastMVM = outputTileToMVM.find(outputTileId);
    // If an entry for this outputTile already exists, apply reduction.
    if (lastMVM != outputTileToMVM.end()) {
      // MVM results should have the same type for reduction.
      assert(lastMVM->second.getType() == result.getType());
      result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
    }
    outputTileToMVM[outputTileId] = result;
    return result;
  }

  /**
   * @brief Mark a result as remappable, and return a shared pointer to it.
   *
   * This function marks a result as remappable, and returns a shared pointer to
   * it. We need to keep track of these values to generate the YieldOp at a
   * later stage.
   *
   * @param result A result to track, for later remapping.
   * @return shared_ptr<Value> A shared pointer to the result.
   */
  shared_ptr<Value> makeResultRemappable(Value result) {
    // Verify that the result is present in the block.
    assert(result.getDefiningOp()->getBlock() == block.get());
    shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
    resultsToRemap.push_back(remappableResult);
    results.push_back(result);
    return remappableResult;
  }

  /**
   * @brief Add a remappable operand to the core, to merge partial results
   * inter-core.
   *
   * @param remappableOperand The operand to add.
   * @return Value The block argument representing the operand.
   */
  Value addRemappableOperand(std::shared_ptr<Value> operand) {
    // Check that the operand is not already there.
    // Note: the new argument is deliberately NOT registered in operandMap;
    // remappable operands are appended later by addRemappedOperands().
    assert(not operandMap.contains(*operand));
    Value argument = block->addArgument(operand->getType(), operand->getLoc());
    remappableOperands.push_back(operand);
    return argument;
  }

  /**
   * @brief Generate a spatial::SpatWeightedCompute operation from the core.
   *
   * @param loc The location of the operation.
   * @return spatial::SpatWeightedCompute
   */
  spatial::SpatWeightedCompute createWComputeOp(Location loc) {
    // Get the shape of the results.
    SmallVector<Type> resultTypes;
    for (const auto& value : results)
      resultTypes.push_back(value.getType());
    // Create the WComputeOp, with non-remappable operands only.
    wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
    // Add the body to the WComputeOp. Ownership of the block is transferred
    // from this Core to the op region, hence the release().
    Block* releasedBlock = block.release();
    wcomputeOp.getBody().push_back(releasedBlock);
    // Add the `yieldOp` at the end, with the results.
    rewriter.setInsertionPointToEnd(releasedBlock);
    rewriter.create<spatial::SpatYieldOp>(loc, results);
    return wcomputeOp;
  }

  /**
   * @brief Remap the results to the WComputeOp results.
   */
  void remapResults() {
    // Remap all the results to the WComputeOp results. Every holder returned
    // by makeResultRemappable now points at the materialized op's results.
    assert(resultsToRemap.size() == wcomputeOp->getNumResults());
    for (size_t i = 0; i < resultsToRemap.size(); i++)
      *resultsToRemap[i] = wcomputeOp.getResult(i);
  }

  // Append the operands registered via addRemappableOperand (which were
  // remapped in `remapResults` of another Core) to the materialized op.
  void addRemappedOperands() {
    // Insert the remappableOperands (which were remapped in
    // `addRemappableOperand` of another Core)
    for (auto remappedValue : remappableOperands)
      wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
    // Update the wcomputeOp operandSegmentSize
    incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
  }

  // Register a crossbar weight and return its index within this core.
  size_t addXbarWeight(Value weight) {
    assert(!isXbarsFull());
    xbarWeights.push_back(weight);
    return xbarWeights.size() - 1;
  }

  // True when all crossbars of the core hold a weight.
  bool isXbarsFull() {
    assert(xbarWeights.size() <= crossbarCountInCore);
    return xbarWeights.size() == crossbarCountInCore;
  }

  // True when no operation has been emitted into the core's body yet.
  bool isCoreEmpty() { return block->empty(); }

  // Debug helper: print the core's weights, operands, body and results.
  void dump() {
    // Print the coreId
    llvm::outs() << "Core " << coreId << ":\n";
    // Print the weights
    llvm::outs() << "Xbar Weights:\n";
    for (auto weight : xbarWeights)
      weight.dump();
    // Print the operands
    llvm::outs() << "Operands:\n";
    for (auto operand : operands)
      llvm::outs() << operand << "\n";
    // Dump the body block
    for (auto& op : block->getOperations())
      op.dump();
    // Print the results
    llvm::outs() << "Results:\n";
    for (auto result : results)
      llvm::outs() << result << "\n";
  }

  const size_t coreId;

private:
  ConversionPatternRewriter& rewriter;
  // Should these be set<Value> instead? But I need to keep the order.
  // Input tiles feeding this core, in block-argument order.
  vector<Value> operands;
  // Operands produced by other cores, appended after materialization.
  vector<std::shared_ptr<Value>> remappableOperands;
  // Values yielded by this core's body.
  vector<Value> results;
  // Holders rewritten to the materialized op's results by remapResults().
  vector<std::shared_ptr<Value>> resultsToRemap;
  // Maps from input tiles to the block operand
  IRMapping operandMap;
  // Map from outputTileId to MVM operation producing it
  unordered_map<size_t, Value> outputTileToMVM;
  // Weights assigned to this core's crossbars, indexed by addXbarWeight().
  vector<Value> xbarWeights;
  // Body under construction; ownership moves to the op in createWComputeOp().
  unique_ptr<mlir::Block> block = make_unique<Block>();
  // The materialized op; valid only after createWComputeOp() runs.
  spatial::SpatWeightedCompute wcomputeOp;
};
/**
 * @brief Tile an ONNXConvOp into many per-crossbar MVMs, distributed over
 * Cores and materialized as spatial::SpatWeightedCompute ops.
 */
struct ConvToManyGemms : public OpConversionPattern<ONNXConvOp> {
  ConvToManyGemms(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  // One partial result for an output tile, paired with the core producing it.
  struct Producer_t {
    Value value;
    shared_ptr<Core> core;
  };

  LogicalResult
  matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
    ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
    ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
    ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
    ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());

    // Unpack the conv hyper-parameters; missing pads abort the match.
    size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
    unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
    unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
    auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
    if (padUnpackError.has_value())
      return rewriter.notifyMatchFailure(conv, padUnpackError.value());
    // TODO: Pad value at beginning and end of each dimension could be
    // different. We should handle this case.

    // MapOperations mapOperation = MapOperations::None;
    //
    // // If we have just one user, and it is an activation function (or more in
    // // general a mapping operation) just inline it in the computeOps
    // auto firstUserOp = *conv->getUsers().begin();
    // if (conv->hasOneUse()) {
    //   mapOperation = mlirOpToMapOperationEnum(firstUserOp);
    //
    //   if (mapOperation == MapOperations::ONNXSoftmaxOp) {
    //     return rewriter.notifyMatchFailure(
    //         conv, "Softmax not supported as activation for convolutions.");
    //   }
    // }

    size_t input_h = GET_IMAGE_HEIGHT(xShape);
    size_t input_w = GET_IMAGE_WIDTH(xShape);
    size_t output_h = GET_IMAGE_HEIGHT(yShape);
    size_t output_w = GET_IMAGE_WIDTH(yShape);
    size_t krn_h = GET_KERNEL_HEIGHT(wShape);
    size_t krn_w = GET_KERNEL_WIDTH(wShape);
    Location loc = conv.getLoc();

    // Split channels into tiles of at most crossbarSize; the remainder tile
    // (if any) is narrower.
    size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
    size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
    size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
    size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;

    // Tile the input tensor
    // Input tiles need to be indexed by:
    // a. Channel Tile
    // b. Pixel `x` position
    // c. Pixel `y` position
    // For example: inputTiles[channelTile][x][y]
    // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
    SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
        inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
    // NOTE(review): `input_h` is passed twice below; given that inputTiles is
    // sized [input_w][input_h], the first of the pair presumably should be
    // input_w — confirm against resolveImgInputTiles' signature (only benign
    // when the image is square).
    auto resolveErrorOpt = resolveImgInputTiles(
        convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
    if (resolveErrorOpt.has_value())
      return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);

    SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
        rewriter.getIndexAttr(crossbarSize),
        rewriter.getIndexAttr(1),
        rewriter.getIndexAttr(1)};

    // Tile the weight tensor
    // Weight tiles need to be indexed by:
    // a. Filter Tile
    // b. Channel Tile
    // c. Kernel `x` position
    // d. Kernel `y` position
    // For example: weightTiles[filterTile][channelTile][x][y]
    // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
    SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
        outputTileCount,
        SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
            SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
    strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
    offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
    sizes = {rewriter.getIndexAttr(crossbarSize),
        rewriter.getIndexAttr(crossbarSize),
        rewriter.getIndexAttr(1),
        rewriter.getIndexAttr(1)};
    // Extract one crossbar-sized weight slice per (filterTile, channelTile,
    // kernel position); remainder tiles shrink the slice accordingly.
    for (size_t i = 0; i < outputTileCount; i++) {
      if (i == outputTileCount - 1 && outputTileRemainder != 0)
        sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
      sizes[1] = rewriter.getIndexAttr(crossbarSize);
      offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
      for (size_t j = 0; j < inputTileCount; j++) {
        if (j == inputTileCount - 1 && inputTileRemainder != 0)
          sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
        for (size_t x = 0; x < krn_w; x++) {
          for (size_t y = 0; y < krn_h; y++) {
            offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
            offsets[2] = rewriter.getIndexAttr(x);
            offsets[3] = rewriter.getIndexAttr(y);
            weightTiles[i][j][x][y] =
                rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
          }
        }
      }
    }

    /* Distribute the computation among many compute cores
     * Try to compute in-core the computation for each output tile, and reduce
     * over as few cores as possible
     */
    // Tile the output tensor
    // Output tiles need to be indexed by:
    // a. Filter Tile
    // b. Pixel `x` position
    // c. Pixel `y` position
    // For example: outputTiles[filterTile][x][y]
    // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
    SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
        outputTileCount,
        SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));

    // Replication factor (how many cores share one column-slice of the input);
    // defaults to 1 when the annotation pass did not set it.
    size_t replicationFactor;
    if (!conv->hasAttr(REPLICATION_ATTR_NAME))
      replicationFactor = 1;
    else
      replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();

    // producers[outTile][out_x][out_y][producerIndex]
    vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
        outputTileCount,
        vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));

    // Schedule in cores
    size_t coreId = 0;
    vector<shared_ptr<Core>> curCores(replicationFactor);
    for (size_t i = 0; i < replicationFactor; i++)
      curCores[i] = make_shared<Core>(coreId++, rewriter);
    vector<shared_ptr<Core>> cores;
    const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
    for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
      for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
        RankedTensorType mvmOutType =
            RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
        for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
          if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
            mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
          for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
            // The same weight slice is loaded on one crossbar of each replica.
            vector<size_t> xbarIndexes(replicationFactor);
            for (size_t i = 0; i < replicationFactor; i++)
              xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
            size_t out_x = 0;
            for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
              size_t out_y = 0;
              // I use `replicationFactor` cores. I divide the input_w into
              // `replicationFactor` slices, and each slice is distributed to a
              // core. `coreIndex` is the index of the core that will be used
              // for this slice
              size_t coreIndex = in_x / replicationSliceSize;
              assert(coreIndex < replicationFactor);
              for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
                // Adjust the input based on the kernel
                // NOTE(review): krn_x iterates up to krn_h but the x-adjustment
                // uses krn_w/2 (and vice versa for y) — presumably relies on
                // the x<->w / y<->h pairing being consistent elsewhere; verify
                // for non-square kernels.
                int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
                int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
                // Check if we are within the input image
                if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
                  out_y++;
                  continue;
                }
                size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
                auto mvm = curCores[coreIndex]->addMVM(
                    inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
                producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
                out_y++;
              }
              out_x++;
            }
            // Computations for these crossbars are done, check if the cores
            // crossbars are fully used. If full, swap with new core
            for (size_t i = 0; i < replicationFactor; i++) {
              if (curCores[i]->isXbarsFull()) {
                cores.emplace_back(std::move(curCores[i]));
                curCores[i] = make_shared<Core>(coreId++, rewriter);
              }
            }
          }
        }
      }
    }
    // Flush the partially-filled cores that still contain work.
    for (auto& curCore : curCores)
      if (curCore->isCoreEmpty() == false)
        cores.emplace_back(std::move(curCore));
    curCores.clear();

    // Now, do the reduction of each output pixel tile
    for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
      for (size_t out_x = 0; out_x < output_w; out_x++) {
        for (size_t out_y = 0; out_y < output_h; out_y++) {
          // First, check if some producers are within the same core. If this is
          // true, `Core::addMVM` have already done the reduction within-core.
          // This means that we only need to consider the last producer for that
          // core.
          std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
          for (auto producer : producers[outTile][out_x][out_y])
            withinCoreReducedProducers[producer.core->coreId] = producer;
          // Now, we need to apply inter-core reduction
          // Base case with one producer
          if (withinCoreReducedProducers.size() == 1) {
            // TODO: Add the bias and apply mapping (if present)
            auto singleProducer = withinCoreReducedProducers.begin()->second;
            // Use last producer as the final result
            auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
            outputTiles[outTile][out_x][out_y] = reducedValue;
            continue;
          }
          // TODO: This is a linear reduction, not a tree reduction. We can do
          // better: a tree reduction would make more computations happen in
          // parallel.
          Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
          auto it = withinCoreReducedProducers.begin();
          it++;
          while (it != withinCoreReducedProducers.end()) {
            Producer_t curProducer = it->second;
            shared_ptr<Core> core1;
            shared_ptr<Core> core2;
            Value core1Value;
            Value core2Value;
            auto lastProducerCoreId = lastProducer.core->coreId;
            auto curProducerCoreId = curProducer.core->coreId;
            assert(lastProducerCoreId != curProducerCoreId
                && "We should have already applied within-core reduction, how "
                   "could we have same cores here?");
            // Sort the cores by coreId
            if (curProducerCoreId < lastProducerCoreId) {
              core1 = curProducer.core;
              core1Value = curProducer.value;
              core2 = lastProducer.core;
              core2Value = lastProducer.value;
            }
            else {
              core1 = lastProducer.core;
              core1Value = lastProducer.value;
              core2 = curProducer.core;
              core2Value = curProducer.value;
            }
            // The lower-id core exports its partial result; the higher-id core
            // receives it as an extra block argument and adds it in.
            auto newCoreRes = core1->makeResultRemappable(core1Value);
            auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
            rewriter.setInsertionPointAfterValue(core2Value);
            Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
                core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
            lastProducer = {vaddRes, core2};
            it++;
          }
          // TODO: Add the bias and apply mapping (if present)
          // Use last producer as the final result
          auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
          outputTiles[outTile][out_x][out_y] = reducedValue;
        }
      }
    }

    // Now, we need to turn the cores into a spatial::SpatWeightedCompute.
    rewriter.setInsertionPointAfter(conv);
    spatial::SpatWeightedCompute lastWComputeOp;
    for (auto& core : cores) {
      lastWComputeOp = core->createWComputeOp(loc);
      core->remapResults();
      rewriter.setInsertionPointAfter(lastWComputeOp);
    }
    // Cross-core operands can only be appended once every core has been
    // materialized and its results remapped.
    for (auto& core : cores)
      core->addRemappedOperands();
    // Set the insertion point after the last WComputeOp.
    rewriter.setInsertionPointAfter(lastWComputeOp);

    SmallVector<Value> tilesToConcat;
    // NOTE(review): this reserves crossbarSize times more slots than the
    // number of elements actually pushed (one Value per tile) — harmless for
    // correctness, but looks like an over-reservation.
    tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
    // NOTE(review): outX is bounded by output_h and outY by output_w while
    // outputTiles is sized [output_w][output_h] — only safe when the output is
    // square; confirm intended x/y convention.
    for (size_t outX = 0; outX < output_h; outX++)
      for (size_t outY = 0; outY < output_w; outY++)
        for (size_t outTile = 0; outTile < outputTileCount; outTile++)
          tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
    Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);

    // Value outputImage =
    //     createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
    // If no mapping (activation) was applied, just replace ConvOp
    // if (mapOperation == MapOperations::None) {
    //   rewriter.replaceOp(conv, outputImage);
    // } else {
    //   // If mapping was applied, erase ConvOp and replace the mapping op
    //   rewriter.eraseOp(conv);
    //   rewriter.replaceOp(firstUserOp, outputImage);
    // }
    return success();
  }
};
/**
 * @brief Register the convolution-tiling pattern with the given pattern set.
 *
 * @param patterns Pattern set to extend with ConvToManyGemms.
 * @param ctx MLIR context used to construct the pattern.
 */
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ConvToManyGemms>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,400 +0,0 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstddef>
#include <unistd.h>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A pattern to tile the convolution operation into a series of compute
* units, each one of which applies filters to a subset of the input
* tensor. Results are also reduced and concatenated to form the final
* output tensor.
*/
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
  ExperimentalONNXConvOpTile(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /**
   * @brief Tile an ONNXConvOp into spatial weighted-compute units.
   *
   * Slices the weight tensor per (input tile, output tile, kernel x, kernel y),
   * groups the slices into compute units via WeightSubdivider, emits one
   * spatial::SpatWeightedCompute per group (with in-unit partial-sum reduction
   * through spatial::SpatVAddOp), and concatenates the per-output-tile results
   * along the channel axis. A directly following ONNXReluOp single user is
   * folded away by replacing it with the concatenation.
   *
   * @param conv The convolution being rewritten.
   * @param convAdaptor Adaptor exposing the (possibly converted) operands.
   * @param rewriter Conversion rewriter used to build the new IR.
   * @return success() once the convolution (and possibly its ReLU user) has
   *         been replaced.
   */
  LogicalResult
  matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
    // --------------------------------- //
    // --- READ OPERATION PARAMETERS --- //
    // --------------------------------- //
    // To get each crossbar's weights, we need to slice the weights tensor.
    // - Along the input tiles.
    // - Along the output tiles.
    // - Along the filter x position.
    // - Along the filter y position.
    ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
    ShapedType outputType = cast<ShapedType>(conv.getY().getType());
    ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
    // TODO: Address bigger batches.
    // NOTE: the two string fragments concatenate; keep the trailing space so
    // the message reads "must be 1 for convolution".
    assert(GET_IMAGE_N(inputType) == 1
           && "Batch size must be 1 "
              "for convolution.");
    // TODO: Address replication.
    assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
    // TODO: Address bias addition.
    ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
    ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
    size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
    size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);
    // Assert that the kernel is square.
    assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
    // -------------------------------- //
    // --- SLICE THE WEIGHTS TENSOR --- //
    // -------------------------------- //
    // The core idea of this stage is classifying the weights by input and
    // output tile. This is because we want the applyFilters operations to be
    // tile agnostic, to keep the subsequent lowering stages as simple as
    // possible. This data structure does this weight classification:
    // - The outer map is indexed by input tile.
    // - The inner map is indexed by output tile.
    // - The SmallVector contains the weights for the filter.
    map<long, map<long, SmallVector<Value>>> weightsGroups;
    // During all slicing operations within this stage, we'll use the same
    // strides for all dimensions.
    SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));
    ldiv_t itc = inputTileCount;
    ldiv_t otc = outputTileCount;
    // - Slicing along the input tiles.
    // - Slicing along the output tiles.
    for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
      long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
      for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
        long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
        // The loop above also sets the crossbar's used width and height,
        // checking if we're at the last crossbar and if it's incomplete.
        long outputTile = ot;
        long inputTile = it;
        // Create the slicing sizes.
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
            /* 1 */ rewriter.getIndexAttr(crossbarWidth),
            /* 2 */ rewriter.getIndexAttr(1),
            /* 3 */ rewriter.getIndexAttr(1)};
        // - Slicing along the filter x position.
        // - Slicing along the filter y position.
        for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
          for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
            // Create the slicing offsets.
            SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
                /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
                /* 2 */ rewriter.getIndexAttr(filterX),
                /* 3 */ rewriter.getIndexAttr(filterY)};
            // Create the slice extraction operation.
            auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
                conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
            // Add a note to the extractSliceOp, with the filterX and filterY.
            weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
          }
        }
      }
    }
    // TODO: Tree reduction for compute reduction should be implemented.
    // -------------------------------- //
    // --- CREATE ALL COMPUTE UNITS --- //
    // -------------------------------- //
    // Keep track of input slicing operations to avoid duplication across
    // all compute units (global slices).
    map<long, Value> globalSlices;
    // Keep track of all partial compute results.
    map<long, Value> globalPartialResults;
    // Use a weight subdivider to extract groups of weights for each compute
    // unit. We'll keep extracting groups until no more weights are left.
    WeightSubdivider weightSubdivider(weightsGroups);
    while (!weightSubdivider.isEmpty()) {
      // -------------------------------- //
      // --- BEGIN A NEW COMPUTE UNIT --- //
      // -------------------------------- //
      // Get the next group of weights for the compute unit.
      SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
      SmallVector<Value> computeWeights;
      SmallVector<Value> computeOperands;
      // ------------------------------ //
      // --- SLICE THE INPUT TENSOR --- //
      // ------------------------------ //
      // Note each tile's index in the compute unit arguments.
      map<long, size_t> inputTileIndices;
      map<long, size_t> outputTileIndices;
      map<long, size_t> reductionTileIndices; // Incoming partial results.
      // Iterate over all weights groups for this compute unit.
      map<long, Value> localSlices; // WRT the current compute unit.
      for (auto group : weightsGroups) {
        for (Value weight : group.weights)
          computeWeights.push_back(weight);
        // There might be multiple weight groups for the same input tile, so if
        // we've already added the input tile, skip it.
        if (localSlices.find(group.inputTile) != localSlices.end())
          continue;
        // We might have already sliced the input tensor for some other compute
        // unit, so if we have, reuse the slicing operation without creating a
        // new one.
        if (globalSlices.find(group.inputTile) != globalSlices.end()) {
          computeOperands.push_back(globalSlices[group.inputTile]);
          localSlices[group.inputTile] = globalSlices[group.inputTile];
          continue;
        }
        // Create the input tensor slicing offsets.
        SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
            /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
            /* 2 */ rewriter.getIndexAttr(0),
            /* 3 */ rewriter.getIndexAttr(0)};
        // Create the input tensor slicing sizes.
        size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
            /* 1 */ rewriter.getIndexAttr(tilingSize),
            /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
            /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
        // Create the slice extraction operation.
        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
            conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
        computeOperands.push_back(extractSliceOp);
        // Update slicing maps.
        globalSlices[group.inputTile] = extractSliceOp;
        localSlices[group.inputTile] = extractSliceOp;
        // Update the input tile index.
        inputTileIndices[group.inputTile] = computeOperands.size() - 1;
      }
      // ------------------------------- //
      // --- PREPARE THE OUTPUT TYPE --- //
      // ------------------------------- //
      // Fill the compute output's type by looking at the output tiles.
      SmallVector<Type> computeOutputType;
      for (TaggedWeights group : weightsGroups) {
        // There might be multiple weight groups for the same output tile, so if
        // we've already added the output tile, skip it.
        if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
          continue;
        // Additionally, after adding the input slices as operands, also add any
        // compatible partial results from previous compute units.
        if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
          computeOperands.push_back(globalPartialResults[group.outputTile]);
          reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
        }
        // Define the output shape for this group.
        long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
        // TODO: Address non-same padding.
        SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
            /* 1 */ outputTileSize,
            /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
            /* 3 */ GET_IMAGE_HEIGHT(outputType)};
        auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
        computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
        outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
      }
      // ----------------------------- //
      // --- FILL THE COMPUTE UNIT --- //
      // ----------------------------- //
      // Create the compute unit.
      spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
          conv.getLoc(), computeOutputType, computeWeights, computeOperands);
      // Create a new block for the compute unit and add the operands.
      Block* block = rewriter.createBlock(&currentCompute.getRegion());
      rewriter.setInsertionPointToStart(block);
      for (Value operand : computeOperands)
        block->addArgument(operand.getType(), conv->getLoc());
      // Initialize a map of local partial results.
      map<long, Value> localPartialResults; // WRT the current compute unit.
      // If we have any reduction tiles, add them to the local partial results.
      for (auto reductionTileIndex : reductionTileIndices)
        localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
      // Add all the applyFilters operations to the block.
      for (TaggedWeights group : weightsGroups) {
        // Get the outputType for this group.
        Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
        // Create an apply filters operation.
        BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
        // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
        // ... As many weights as the size of group.weights.
        SmallVector<long> weightIndices;
        for (size_t i = 0; i < group.weights.size(); ++i)
          weightIndices.push_back(group.startingCrossbarIndex + i);
        SmallVector<int64_t> xKerPos;
        SmallVector<int64_t> yKerPos;
        for (auto weight : group.weights) {
          // Assert that the weight is an extract_slice operation.
          auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
          assert(extractSliceOp && "Weight is not an extract_slice operation.");
          // Get the filter x and y positions from the extract_slice operation.
          auto offsets = extractSliceOp.getStaticOffsets();
          xKerPos.push_back(offsets[2]);
          yKerPos.push_back(offsets[3]);
        }
        ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
        ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
        ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
        Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
            conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
        // Perform local reduction if necessary.
        if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
          result = rewriter.create<spatial::SpatVAddOp>(
              conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
        }
        // Update the partial results map.
        localPartialResults[group.outputTile] = result;
      }
      // Add a yield operation to the block by concatenating the partial
      // results.
      SmallVector<Value> applyFiltersResults;
      for (size_t i = 0; i < computeOutputType.size(); ++i) {
        // Initialize to a sentinel so a missing index mapping is caught by the
        // assert below instead of reading an uninitialized value.
        long outputTile = -1;
        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }
        assert(outputTile >= 0 && "No output tile maps to this result index.");
        // Get that tile's partial result and add it to the list.
        applyFiltersResults.push_back(localPartialResults[outputTile]);
      }
      // Create the yield operation with the given results.
      rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);
      // Update the global partial results map.
      for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
        // Same sentinel pattern as above to avoid an uninitialized read.
        long outputTile = -1;
        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }
        assert(outputTile >= 0 && "No output tile maps to this result index.");
        globalPartialResults[outputTile] = currentCompute.getResult(i);
      }
      // Move the rewrite cursor out of the block.
      rewriter.setInsertionPointAfter(currentCompute);
    }
    // ------------------------------ //
    // --- CONCATENATE THE OUTPUT --- //
    // ------------------------------ //
    // Turn the values into a SmallVector.
    SmallVector<Value> outputValues;
    for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
      outputValues.push_back(globalPartialResults[i]);
    // Assert that the number of output values is correct.
    assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
    // If the conv's user is a ReLU...
    if (conv->hasOneUse()) {
      Operation* user = *conv->getUsers().begin();
      if (auto relu = dyn_cast<ONNXReluOp>(user)) {
        // ...then we can just replace the ReLU with the concatenation.
        rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
        // And erase the convolution (its only user was just replaced, so it
        // is dead at this point).
        rewriter.eraseOp(conv);
        return success();
      }
    }
    // Return the final output.
    rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
    return success();
  }
};
/**
* @brief Populate the tiling pattern for a convolution operation.
*
* @param patterns The pattern set to populate.
* @param ctx The MLIR context.
*/
/**
 * @brief Register the experimental convolution-tiling pattern with the given
 *        pattern set.
 *
 * @param patterns Pattern set to extend with ExperimentalONNXConvOpTile.
 * @param ctx MLIR context used to construct the pattern.
 */
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ExperimentalONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,365 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdlib>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
  ExperimentalGemmConversionPattern(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /**
   * @brief Tile an ONNXGemmOp into spatial weighted-compute units.
   *
   * Treats the GEMM as a 1x1-kernel convolution: slices the B matrix per
   * (input tile, output tile), groups the slices into compute units via
   * WeightSubdivider, emits one spatial::SpatWeightedCompute per group (with
   * in-unit partial-sum reduction through spatial::SpatVAddOp), and
   * concatenates the per-output-tile results along axis 1.
   *
   * @param gemmOp The GEMM being rewritten.
   * @param adaptor Adaptor exposing the (possibly converted) operands.
   * @param rewriter Conversion rewriter used to build the new IR.
   * @return success() once the GEMM has been replaced.
   */
  LogicalResult
  matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    // --------------------------------- //
    // --- READ OPERATION PARAMETERS --- //
    // --------------------------------- //
    // To get each crossbar's weights, we need to slice the weights tensor.
    // - Along the input tiles.
    // - Along the output tiles.
    // - Along the filter x position.
    // - Along the filter y position.
    ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
    ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
    ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
    // TODO: Address bigger batches.
    assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
    // TODO: Address replication.
    assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
    // TODO: Address bias addition.
    assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
    ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
    ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
    // A GEMM is treated as a convolution with a 1x1 kernel.
    size_t kernelWidth = 1;
    size_t kernelHeight = 1;
    // Assert that the kernel is square.
    assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
    // -------------------------------- //
    // --- SLICE THE WEIGHTS TENSOR --- //
    // -------------------------------- //
    // The core idea of this stage is classifying the weights by input and
    // output tile. This is because we want the applyFilters operations to be
    // tile agnostic, to keep the subsequent lowering stages as simple as
    // possible. This data structure does this weight classification:
    // - The outer map is indexed by input tile.
    // - The inner map is indexed by output tile.
    // - The SmallVector contains the weights for the filter.
    map<long, map<long, SmallVector<Value>>> weightsGroups;
    // During all slicing operations within this stage, we'll use the same
    // strides for all dimensions.
    SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
    ldiv_t itc = inputTileCount;
    ldiv_t otc = outputTileCount;
    // - Slicing along the input tiles.
    // - Slicing along the output tiles.
    for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
      long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
      for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
        long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
        // The loop above also sets the crossbar's used width and height,
        // checking if we're at the last crossbar and if it's incomplete.
        long outputTile = ot;
        long inputTile = it;
        // Create the slicing sizes.
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
            /* 1 */ rewriter.getIndexAttr(crossbarWidth),
            /* 2 */ /* rewriter.getIndexAttr(1), */
            /* 3 */ /* rewriter.getIndexAttr(1) */};
        // - Slicing along the filter x position.
        // - Slicing along the filter y position.
        for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
          for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
            // Create the slicing offsets.
            SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
                /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
                /* 2 */ /* rewriter.getIndexAttr(filterX), */
                /* 3 */ /* rewriter.getIndexAttr(filterY) */};
            // Create the slice extraction operation.
            auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
                gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
            // Add a note to the extractSliceOp, with the filterX and filterY.
            weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
          }
        }
      }
    }
    // TODO: Tree reduction for compute reduction should be implemented.
    // -------------------------------- //
    // --- CREATE ALL COMPUTE UNITS --- //
    // -------------------------------- //
    // Keep track of input slicing operations to avoid duplication across
    // all compute units (global slices).
    map<long, Value> globalSlices;
    // Keep track of all partial compute results.
    map<long, Value> globalPartialResults;
    // Use a weight subdivider to extract groups of weights for each compute
    // unit. We'll keep extracting groups until no more weights are left.
    WeightSubdivider weightSubdivider(weightsGroups);
    while (!weightSubdivider.isEmpty()) {
      // -------------------------------- //
      // --- BEGIN A NEW COMPUTE UNIT --- //
      // -------------------------------- //
      // Get the next group of weights for the compute unit.
      SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
      SmallVector<Value> computeWeights;
      SmallVector<Value> computeOperands;
      // ------------------------------ //
      // --- SLICE THE INPUT TENSOR --- //
      // ------------------------------ //
      // Note each tile's index in the compute unit arguments.
      map<long, size_t> inputTileIndices;
      map<long, size_t> outputTileIndices;
      map<long, size_t> reductionTileIndices; // Incoming partial results.
      // Iterate over all weights groups for this compute unit.
      map<long, Value> localSlices; // WRT the current compute unit.
      for (auto group : weightsGroups) {
        for (Value weight : group.weights)
          computeWeights.push_back(weight);
        // There might be multiple weight groups for the same input tile, so if
        // we've already added the input tile, skip it.
        if (localSlices.find(group.inputTile) != localSlices.end())
          continue;
        // We might have already sliced the input tensor for some other compute
        // unit, so if we have, reuse the slicing operation without creating a
        // new one.
        if (globalSlices.find(group.inputTile) != globalSlices.end()) {
          computeOperands.push_back(globalSlices[group.inputTile]);
          localSlices[group.inputTile] = globalSlices[group.inputTile];
          continue;
        }
        // Create the input tensor slicing offsets.
        SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
            /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
            /* 2 */ /* rewriter.getIndexAttr(0), */
            /* 3 */ /* rewriter.getIndexAttr(0) */};
        // Create the input tensor slicing sizes.
        size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
            /* 1 */ rewriter.getIndexAttr(tilingSize),
            /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
            /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
        // Create the slice extraction operation.
        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
            gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
        computeOperands.push_back(extractSliceOp);
        // Update slicing maps.
        globalSlices[group.inputTile] = extractSliceOp;
        localSlices[group.inputTile] = extractSliceOp;
        // Update the input tile index.
        inputTileIndices[group.inputTile] = computeOperands.size() - 1;
      }
      // ------------------------------- //
      // --- PREPARE THE OUTPUT TYPE --- //
      // ------------------------------- //
      // Fill the compute output's type by looking at the output tiles.
      SmallVector<Type> computeOutputType;
      for (TaggedWeights group : weightsGroups) {
        // There might be multiple weight groups for the same output tile, so if
        // we've already added the output tile, skip it.
        if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
          continue;
        // Additionally, after adding the input slices as operands, also add any
        // compatible partial results from previous compute units.
        if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
          computeOperands.push_back(globalPartialResults[group.outputTile]);
          reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
        }
        // Define the output shape for this group.
        long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
        // TODO: Address non-same padding.
        SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
            /* 1 */ outputTileSize,
            /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
            /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
        auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
        computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
        outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
      }
      // ----------------------------- //
      // --- FILL THE COMPUTE UNIT --- //
      // ----------------------------- //
      // Create the compute unit.
      spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
          gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
      // Create a new block for the compute unit and add the operands.
      Block* block = rewriter.createBlock(&currentCompute.getRegion());
      rewriter.setInsertionPointToStart(block);
      for (Value operand : computeOperands)
        block->addArgument(operand.getType(), gemmOp->getLoc());
      // Initialize a map of local partial results.
      map<long, Value> localPartialResults; // WRT the current compute unit.
      // If we have any reduction tiles, add them to the local partial results.
      for (auto reductionTileIndex : reductionTileIndices)
        localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
      // Add all the applyFilters operations to the block.
      for (TaggedWeights group : weightsGroups) {
        // Get the outputType for this group.
        Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
        // Create an apply filters operation.
        BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
        // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
        // ... As many weights as the size of group.weights.
        SmallVector<long> weightIndices;
        for (size_t i = 0; i < group.weights.size(); ++i)
          weightIndices.push_back(group.startingCrossbarIndex + i);
        SmallVector<int64_t> xKerPos;
        SmallVector<int64_t> yKerPos;
        for (auto weight : group.weights) {
          // Assert that the weight is an extract_slice operation.
          auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
          assert(extractSliceOp && "Weight is not an extract_slice operation.");
          // The GEMM kernel is 1x1, so the filter positions are always (0, 0).
          xKerPos.push_back(0);
          yKerPos.push_back(0);
        }
        ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
        ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
        ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
        Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
            gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
        // Perform local reduction if necessary.
        if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
          result = rewriter.create<spatial::SpatVAddOp>(
              gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
        }
        // Update the partial results map.
        localPartialResults[group.outputTile] = result;
      }
      // Add a yield operation to the block by concatenating the partial
      // results.
      SmallVector<Value> applyFiltersResults;
      for (size_t i = 0; i < computeOutputType.size(); ++i) {
        // Initialize to a sentinel so a missing index mapping is caught by the
        // assert below instead of reading an uninitialized value.
        long outputTile = -1;
        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }
        assert(outputTile >= 0 && "No output tile maps to this result index.");
        // Get that tile's partial result and add it to the list.
        applyFiltersResults.push_back(localPartialResults[outputTile]);
      }
      // Create the yield operation with the given results.
      rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
      // Update the global partial results map.
      for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
        // Same sentinel pattern as above to avoid an uninitialized read.
        long outputTile = -1;
        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }
        assert(outputTile >= 0 && "No output tile maps to this result index.");
        globalPartialResults[outputTile] = currentCompute.getResult(i);
      }
      // Move the rewrite cursor out of the block.
      rewriter.setInsertionPointAfter(currentCompute);
    }
    // ------------------------------ //
    // --- CONCATENATE THE OUTPUT --- //
    // ------------------------------ //
    // Turn the values into a SmallVector.
    SmallVector<Value> outputValues;
    for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
      outputValues.push_back(globalPartialResults[i]);
    // Assert that the number of output values is correct.
    assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
    // Return the final output.
    rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
    return success();
  }
};
/**
 * @brief Register the experimental GEMM conversion pattern with the given
 *        pattern set.
 *
 * @param patterns Pattern set to extend with ExperimentalGemmConversionPattern.
 * @param ctx MLIR context used to construct the pattern.
 */
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ExperimentalGemmConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -10,7 +10,6 @@
#include <cassert>
#include "Gemm.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
@@ -20,6 +19,38 @@
using namespace mlir;
namespace onnx_mlir {
namespace {
// Name of the attribute used to tag a compute op whose result still has to be
// divided by the softmax reduction sum (see softmaxReductionApplication).
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
// Conversion pattern splitting one ONNXGemmOp into several GEMV-shaped ops.
// The rewrite logic lives in the out-of-line matchAndRewrite definition below.
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
// Conversion pattern lowering a GEMV-shaped ONNXGemmOp to spatial compute
// operations. Registered with benefit 1, so ONNXGemmOp patterns with a higher
// benefit are tried first.
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
private:
// Traverses the use chain starting at startValue until it finds an ONNXExpOp
// and returns that op's value (used to locate the softmax exponent).
static Value resolveONNXExpOpFromUseChain(Value startValue);
// Softmax needs a second reduction: the cores have already applied
// f(x) = exp(x) to each tile, so the tiles must be reduce-summed and each
// tile divided by that sum, which is propagated back to the cores via the
// softmaxChannel broadcast channel.
static LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel,
ConversionPatternRewriter& rewriter,
SpatialReducer& reducer,
ONNXGemmOp& gemmOp,
Location& loc);
};
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,

View File

@@ -1,54 +0,0 @@
#pragma once
#include "Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {

// Name of the attribute used to tag a compute op whose result still has to be
// divided by the softmax reduction sum (see softmaxReductionApplication).
constexpr mlir::StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";

/**
 * @brief Conversion pattern splitting one ONNXGemmOp into several GEMV-shaped
 *        ONNXGemmOps. Registered with benefit 2 so it is tried before
 *        GemvToSpatialCompute (benefit 1).
 */
struct GemmToManyGemv : mlir::OpConversionPattern<mlir::ONNXGemmOp> {
  GemmToManyGemv(mlir::MLIRContext* ctx)
      : OpConversionPattern(ctx, 2) {}

  mlir::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp,
      mlir::ONNXGemmOpAdaptor gemmOpAdaptor,
      mlir::ConversionPatternRewriter& rewriter) const override;
};

/**
 * @brief Conversion pattern lowering a GEMV-shaped ONNXGemmOp to spatial
 *        compute operations.
 */
struct GemvToSpatialCompute : mlir::OpConversionPattern<mlir::ONNXGemmOp> {
  GemvToSpatialCompute(mlir::MLIRContext* ctx)
      : OpConversionPattern(ctx, 1) {}

  // Consistently spelled mlir::LogicalResult (it aliases llvm::LogicalResult)
  // to match GemmToManyGemv above.
  mlir::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp,
      mlir::ONNXGemmOpAdaptor gemmOpAdaptor,
      mlir::ConversionPatternRewriter& rewriter) const override;

private:
  /**
   * Resolves the ONNXExpOp from the use chain of the given start value.
   *
   * This function traverses the use chain of the start value until it finds an
   * ONNXExpOp. It returns the value of the ONNXExpOp.
   *
   * @param startValue The starting value of the use chain.
   * @return The value of the ONNXExpOp found in the use chain.
   */
  static mlir::Value resolveONNXExpOpFromUseChain(mlir::Value startValue);

  // Softmax is a special case, as it requires another reduction after the
  // first one. In the cores, `applyReducePattern` already applied
  // f(x) = exp(x) to each tile. This means that now we just need to
  // reduce-sum these tiles, and then divide each tile by the reduced sum,
  // which is propagated back to the cores via a broadcast channel.
  // All types are fully qualified so this header does not depend on a leaked
  // `using namespace mlir` from its includes.
  static mlir::LogicalResult softmaxReductionApplication(llvm::SmallVector<OpAndResNum>& outputOpsAndResNums,
      mlir::Value& softmaxChannel,
      mlir::ConversionPatternRewriter& rewriter,
      SpatialReducer& reducer,
      mlir::ONNXGemmOp& gemmOp,
      mlir::Location& loc);
};

/// Registers both GEMM lowering patterns above with the given pattern set.
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

} // namespace onnx_mlir

View File

@@ -1,300 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
/// Trait-like predicate: does this pooling op require a post-processing step
/// on each reduced window? The generic answer is "no"; pooling ops opt in via
/// an explicit specialization (see ONNXAveragePoolOp below).
template <typename PoolOp>
bool hasPostProcessExperimentalPoolingWindow() { return false; }
// Average pooling opts into post-processing: each reduced window must be
// divided by the number of contributing elements.
template <>
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
  return true;
}
/// Hook invoked after a pooling window has been reduced. The generic version
/// performs no post-processing and returns a null Value; pooling ops that
/// need it (see the ONNXAveragePoolOp specialization) provide their own.
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
    Location loc,
    PoolOp poolOp,
    Value valueToDivide,
    size_t krn_size,
    size_t tilesSkippedByPadding) {
  return nullptr;
}
/// Average-pool post-processing: divides the reduced window by the element
/// count. The divisor honors count_include_pad — when pad elements are not
/// counted, tiles skipped because of padding are subtracted from krn_size.
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
    Location loc,
    ONNXAveragePoolOp poolOp,
    Value valueToDivide,
    size_t krn_size,
    size_t tilesSkippedByPadding) {
  bool countIncludePad = poolOp.getCountIncludePad() == 1;
  // Full kernel size when pads count, otherwise only the real contributions.
  size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
  // The divisor is materialized as a 1-element f32 tensor constant.
  RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
  // Put a spat.const before the computeOp, and use its value. We do this to be
  // compatible with the current code generation, which assumes constant to be
  // loaded in global memory, which is allocated by adding a spat.const OP
  // directly under func.func (i.e. alongside ComputeOps)
  auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
  rewriter.setInsertionPoint(computeOp);
  auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
      scalarTensor,
      rewriter.getI64IntegerAttr(divisorNumber),
      /* should_allocate = */ rewriter.getBoolAttr(true));
  // Restore the insertion point next to the value being divided so the
  // division lands inside the compute body.
  rewriter.setInsertionPointAfterValue(valueToDivide);
  return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
/// Reduces the given tiles into a single value by building a balanced binary
/// tree of ReductionOp operations (SpatVMaxOp for max pooling, SpatVAddOp for
/// average pooling).
///
/// @param inputTiles Non-empty list of tile values to combine.
/// @param rewriter   Rewriter used to create the reduction ops.
/// @return The root value of the reduction tree.
template <typename ReductionOp>
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
  if (inputTiles.size() == 1)
    return inputTiles[0];
  if (inputTiles.size() == 2) {
    // BUGFIX: this branch previously hard-coded spatial::SpatVMaxOp, so a
    // two-tile window silently computed a max even when the caller asked for
    // another reduction (e.g. SpatVAddOp for average pooling).
    return rewriter.create<ReductionOp>(
        inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
  }
  // Split in half, reduce each side recursively, then combine the halves.
  SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
  SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
  Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
  Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
  return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
// Experimental pooling lowering: slices the (assumed pre-tiled, i.e.
// tensor.concat-produced) input image into per-channel-tile 1x1 spatial
// slices, emits a single spat.weighted_compute whose body reduces every
// pooling window with ReduceOp, and reassembles the result with
// spat.img.concat + tensor.concat.
//
// Template parameters:
//   PoolOp        - the ONNX pooling op handled (e.g. ONNXMaxPoolSingleOutOp).
//   PoolOpAdaptor - the matching op adaptor type.
//   ReduceOp      - the spatial op combining two tiles (SpatVMaxOp/SpatVAddOp).
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
  ExperimentalPoolingBaseConverter(MLIRContext* ctx)
      : OpConversionPattern<PoolOp>(ctx) {}
  LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    Value X = adaptor.getX();
    ShapedType xShape = mlir::cast<ShapedType>(X.getType());
    Value Y = poolOp.getResult();
    ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
    // NOTE(review): dilation_x/dilation_y and pad_x/pad_y are unpacked below
    // but never consulted by the tiling logic in this function — confirm
    // whether dilated/padded pooling is meant to be supported here.
    size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
    unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
    unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
    unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
    if (adaptor.getAutoPad() != "NOTSET")
      return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
    size_t pad_x, pad_y;
    auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
    if (padUnpackError.has_value())
      return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
    Location loc = poolOp.getLoc();
    // Per the GET_IMAGE_* macros: dim 2 is width, dim 3 is height, dim 1 is
    // the channel dimension that is tiled by crossbarSize.
    size_t input_h = GET_IMAGE_HEIGHT(xShape);
    size_t input_w = GET_IMAGE_WIDTH(xShape);
    size_t output_h = GET_IMAGE_HEIGHT(yShape);
    size_t output_w = GET_IMAGE_WIDTH(yShape);
    // quot = number of full channel tiles, rem = channels in the last partial tile.
    ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
    // Assert that the input is a tensor.ConcatOp.
    auto concat = X.getDefiningOp<tensor::ConcatOp>();
    if (!concat)
      return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
    // Create a [channel_tile][x][y] array to store the input tiles.
    std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
    // For each argument of the tensor.ConcatOp, resolve the input tiles.
    // NOTE(review): concat.getOperand(it) assumes the concat has exactly one
    // operand per channel tile — verify against the producer of X.
    for (size_t y = 0; y < input_h; ++y) {
      for (size_t x = 0; x < input_w; ++x) {
        for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
          // The last tile holds the remainder channels when rem > 0.
          size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
          SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
          SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
              /* 1 */ rewriter.getIndexAttr(0),
              /* 2 */ rewriter.getIndexAttr(x),
              /* 3 */ rewriter.getIndexAttr(y)};
          SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
              /* 1 */ rewriter.getIndexAttr(tilingSize),
              /* 2 */ rewriter.getIndexAttr(1),
              /* 3 */ rewriter.getIndexAttr(1)};
          // Get the concat's operand that we want to slice.
          Value concatInput = concat.getOperand(it);
          Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
          inputTiles[it][x][y] = slicedTile;
        }
      }
    }
    // Prepare the shape of the compute's output.
    ldiv_t itc = tileCount;
    SmallVector<Type> outputTileTypes;
    for (size_t y = 0; y < output_h; ++y) {
      for (size_t x = 0; x < output_w; ++x) {
        for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
          // Each output tile is a 1x1 spatial slice with as many channels as
          // the corresponding input tile.
          SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
              /* 1 */
              cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
              /* 2 */ 1,
              /* 3 */ 1};
          auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
          outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
        }
      }
    }
    // Create a plain value list of the input tiles.
    // NOTE(review): tiles were stored above as inputTiles[it][x][y] but are
    // read here as [it][y][x]; for non-square images this reads
    // default-constructed (null) std::map entries, and for square ones it
    // transposes the spatial layout — confirm the intended index order.
    SmallVector<Value> inputTilesList;
    for (size_t y = 0; y < input_h; ++y) {
      for (size_t x = 0; x < input_w; ++x)
        for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
          inputTilesList.push_back(inputTiles[it][y][x]);
    }
    // Create a single compute to calculate the output.
    auto computeOp =
        rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
    // Create a new block for the compute unit and add the operands.
    Block* block = rewriter.createBlock(&computeOp.getRegion());
    // Fill the block arguments and keep a reference to them.
    std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
    for (size_t y = 0; y < input_h; ++y) {
      for (size_t x = 0; x < input_w; ++x) {
        for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
          // Linear index matching the (y, x, it) push order of inputTilesList.
          auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
          inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
        }
      }
    }
    // Begin writing in the block.
    rewriter.setInsertionPointToStart(block);
    // Go through all pooling blocks.
    SmallVector<Value> outputTiles;
    for (size_t y = 0; y < output_h; ++y) {
      for (size_t x = 0; x < output_w; ++x) {
        for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
          // Pooling window for output (x, y), clipped at the input border.
          size_t start_x = x * stride_x;
          size_t start_y = y * stride_y;
          size_t end_x = std::min(start_x + krn_w, input_w);
          size_t end_y = std::min(start_y + krn_h, input_h);
          SmallVector<Value> inputTilesToReduce;
          for (size_t ky = start_y; ky < end_y; ++ky)
            for (size_t kx = start_x; kx < end_x; ++kx)
              inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
          auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
          // If the reduce op is add, we need to divide the result by the
          // number of elements in the pooling window.
          // NOTE(review): the divisor is always krn_w * krn_h even for border
          // windows that std::min clipped to fewer elements, and
          // count_include_pad is not consulted here — verify average-pool
          // semantics at the edges.
          if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
            // Add a spat.const before the computeOp.
            rewriter.setInsertionPoint(computeOp);
            auto divisorValue =
                rewriter.create<spatial::SpatConstantOp>(loc,
                    RankedTensorType::get({1}, rewriter.getF32Type()),
                    rewriter.getI64IntegerAttr(krn_w * krn_h),
                    rewriter.getBoolAttr(true));
            rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
            reduceResult =
                rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
          }
          outputTiles.push_back(reduceResult);
        }
      }
    }
    // Create a YieldOp to return the output tiles.
    rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
    // Set the rewrite cursor right after the computeOp.
    rewriter.setInsertionPointAfter(computeOp);
    // Re-index the compute results by (tile, y, x), mirroring the yield order.
    std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
    for (size_t y = 0; y < output_h; ++y) {
      for (size_t x = 0; x < output_w; ++x) {
        for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
          auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
          computeOutput[it][y][x] = computeOp.getResult(tileIndex);
        }
      }
    }
    // We'll now create spat.img.concat ops to concatenate the output tiles.
    SmallVector<Value> outputTilesList;
    for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
      SmallVector<Value> imgConcatTiles;
      for (size_t y = 0; y < output_h; ++y)
        for (size_t x = 0; x < output_w; ++x)
          imgConcatTiles.push_back(computeOutput[it][y][x]);
      size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
      SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
          /* 1 */ (long) tilingSize,
          /* 2 */ (long) output_w,
          /* 3 */ (long) output_h};
      auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
      outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
          loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
    }
    // Create a new tensor.ConcatOp to concatenate the output tiles.
    // Concatenation is along dim 1, the channel dimension.
    Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
    rewriter.replaceOp(poolOp, outputTensor);
    return success();
  }
};
/// Registers the experimental tiling-based pooling lowerings: max pooling
/// reduces tiles with spat.vmax, average pooling with spat.vadd.
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  using MaxPoolConverter =
      ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>;
  using AvgPoolConverter =
      ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>;
  patterns.insert<MaxPoolConverter, AvgPoolConverter>(ctx);
}
} // namespace onnx_mlir

View File

@@ -26,8 +26,6 @@ using namespace mlir;
namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce,
@@ -225,12 +223,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
size_t input_h = getImageHeight(xShape);
size_t input_w = getImageWidth(xShape);
size_t output_h = getImageHeight(yShape);
size_t output_w = getImageWidth(yShape);
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
// 1: Tile the input tensor
// Input tiles need to be indexed by:

View File

@@ -13,9 +13,7 @@ def onnxToArithConstantOp : Pat<
(Arith_ConstantOp $value)
>;
//===----------------------------------------------------------------------===//
// ONNXMatMulOp to ONNXGemmOp patterns
//===----------------------------------------------------------------------===//
def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
@@ -39,9 +37,7 @@ def matMulToGemmPattern : Pat<
)
>;
//===----------------------------------------------------------------------===//
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
//===----------------------------------------------------------------------===//
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias.
@@ -55,9 +51,7 @@ def convAddToConvWithBiasPatternRight : Pat<
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
//===----------------------------------------------------------------------===//
// Operation to ignore (i.e. remove)
//===----------------------------------------------------------------------===//
def replaceWithOperationOfValue : NativeCodeCall<"$0">;

View File

@@ -180,10 +180,10 @@ void tileImageTensorByChannel(Value imageTensor,
ConversionPatternRewriter& rewriter) {
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
size_t input_h = GET_IMAGE_HEIGHT(imageShape);
size_t input_w = GET_IMAGE_WIDTH(imageShape);
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize);
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize;
size_t input_h = getImageHeight(imageShape);
size_t input_w = getImageWidth(imageShape);
size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
size_t tileRest = getImageChannel(imageShape) % tileSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));

View File

@@ -9,24 +9,55 @@
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
using namespace mlir;
namespace onnx_mlir {
const StringRef REPLICATION_ATTR_NAME = "replication_factor";
/// Width of an image tensor: dimension 2 of the shaped type.
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) { return shapedType.getDimSize(2); }
/// Height of an image tensor: dimension 3 of the shaped type.
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) { return shapedType.getDimSize(3); }
/// Channel count of an image tensor: dimension 1 of the shaped type.
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) { return shapedType.getDimSize(1); }
/// Batch size (N) of an image tensor: dimension 0 of the shaped type.
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) { return shapedType.getDimSize(0); }
/// Width of a convolution kernel tensor: dimension 2 of the shaped type.
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) { return shapedType.getDimSize(2); }
/// Height of a convolution kernel tensor: dimension 3 of the shaped type.
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) { return shapedType.getDimSize(3); }
/// Number of filters in a kernel tensor: dimension 0 of the shaped type.
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) { return shapedType.getDimSize(0); }
inline constexpr mlir::StringRef REPLICATION_ATTR_NAME = "replication_factor";
using HSliceId = size_t;
using CoreId = size_t;
@@ -58,51 +89,64 @@ constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
}
template <class T>
bool isVectorShape(const ArrayRef<T> shape) {
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(const ArrayRef<T> shape) {
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(const ArrayRef<T> shape) {
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(const ArrayRef<T> shape) {
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(const ArrayRef<T> shape) {
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); }
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input);
mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
/**
* Unpacks an optional pair vector into two size_t values.
@@ -126,7 +170,8 @@ void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t
*
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
*/
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
/**
* Tiles the image tensor by channel.
@@ -140,10 +185,10 @@ std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> val
* @param tileSize The size of each tile.
* @param rewriter The ConversionPatternRewriter used for creating operations.
*/
void tileImageTensorByChannel(Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
void tileImageTensorByChannel(mlir::Value imageTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
size_t tileSize,
ConversionPatternRewriter& rewriter);
mlir::ConversionPatternRewriter& rewriter);
/**
* Creates an ImgConcatOp based on the given tiles.
@@ -159,10 +204,10 @@ void tileImageTensorByChannel(Value imageTensor,
*
* @return The created ImgConcatOp.
*/
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
ConversionPatternRewriter& rewriter,
Location& loc,
Type outputType);
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc,
mlir::Type outputType);
/**
* @brief Verifies if the given input coordinates and padding values are within
@@ -177,7 +222,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
* @return LogicalResult Returns success if the coordinates and padding are
* within bounds, failure otherwise.
*/
LogicalResult
mlir::LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
/**
@@ -207,8 +252,9 @@ verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY,
* @return std::optional<llvm::Twine> An error message if the input tensor could
* not be resolved into tiles.
*/
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
std::optional<llvm::Twine>
resolveImgInputTiles(mlir::Value wholeInputTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
@@ -258,6 +304,6 @@ void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcom
* @return The index of the result of the operation that produces the specified
* value.
*/
int getResultIndex(Operation* op, Value v);
int getResultIndex(mlir::Operation* op, mlir::Value v);
}; // namespace onnx_mlir

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
@@ -10,19 +11,39 @@
#include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "Math/Conv.hpp"
#include "ONNXToSpatialPass.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
// Module pass that lowers ONNX operations to the Spatial dialect
// (registered on the command line as "convert-onnx-to-spatial").
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
  ONNXToSpatialPass() = default;
  // Copying yields a fresh instance; no members are copied (the pass is stateless).
  ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
  void runOnOperation() override;
private:
  // Tags constants that are only ever consumed as weights of weighted
  // computes (see the walk over arith::ConstantOp in the definition).
  void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
@@ -40,15 +61,19 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp);
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (annotateReplication(funcOp, rewriter).failed()) {
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
if (annotateReplication(*entryFunc, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n";
signalPassFailure();
return;
}
ConversionTarget target(*ctx);
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>();
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
@@ -62,16 +87,9 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
if (useExperimentalConvImpl) {
populateExperimentalTilingConvOpPattern(patterns, ctx);
populateExperimentalPoolingTilingPattern(patterns, ctx);
populateGemmToConvConversionPattern(patterns, ctx);
}
else {
populateTilingConvOpPattern(patterns, ctx);
populateConvOpPatterns(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
}
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx);
@@ -84,8 +102,8 @@ void ONNXToSpatialPass::runOnOperation() {
// Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations())
if (isa<SpatWeightedCompute>(op))
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
@@ -102,22 +120,21 @@ void ONNXToSpatialPass::runOnOperation() {
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp);
annotateWeightsConstants(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial");
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); });
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
if (isAlwaysWeight)
constantOp->setAttr("weightAlways", UnitAttr::get(ctx));
markWeightAlways(constantOp);
});
}
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
using namespace mlir;
extern bool haveSameStaticShape(Value lhs, Value rhs);
namespace spatial {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -1,27 +1,20 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
// Experimental patterns.
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -10,7 +10,7 @@ using namespace mlir;
namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> {
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {}

View File

@@ -49,11 +49,11 @@ LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& r
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
ShapedType wShape = mlir::cast<ShapedType>(W.getType());
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
size_t input_w = getImageWidth(xShape);
size_t krn_h = getKernelHeight(wShape);
size_t krn_w = getKernelWidth(wShape);
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;

View File

@@ -15,21 +15,21 @@
namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
llvm::SmallPtrSet<mlir::Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun,
ConversionPatternRewriter& rewriter) {
std::function<mlir::Value(const mlir::Value&)> processFun,
mlir::ConversionPatternRewriter& rewriter) {
assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
spatial::SpatYieldOp yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum);
mlir::Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result);
Value processedResult = processFun(result);
mlir::Value processedResult = processFun(result);
if (processedResult == result) {
// Sometimes we want processedResult to return the same value but do
// something else with it (e.g. in softmax we want to broadcast the value
@@ -42,10 +42,11 @@ ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum
return yieldOp.getNumOperands() - 1;
}
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
OpAndResNum
SpatialReducer::applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
std::function<mlir::Value(const mlir::Value&)> preprocess,
std::function<mlir::Value(const mlir::Value&)> postprocess) {
if (preprocess)
for (auto& computeOpAndResNum : computeOpsAndResNum)
@@ -55,18 +56,18 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute;
std::unordered_map<mlir::Operation*, mlir::Value> lastValueForCompute;
for (auto& computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
mlir::Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation());
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
mlir::Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
@@ -85,12 +86,12 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
auto yieldOp = mlir::cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false;
for (auto& use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) {
@@ -110,7 +111,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
computeOpsAndResNum.push_back({computeOp, resultNum});
}
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
mlir::Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
@@ -118,7 +119,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// - Repeat until there is only one input left.
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
while (computeOpsRef.size() > 1) {
SmallVector<ComputeAndResNum> nextComputeOps;
llvm::SmallVector<ComputeAndResNum> nextComputeOps;
nextComputeOps.reserve(computeOpsRef.size() / 2);
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
auto [firstCompute, firstResultNum] = computeOpsRef[i];
@@ -135,23 +136,23 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& co
// the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
auto yieldOpFirstCompute = mlir::cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp
Block& secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
mlir::Block& secondBlock = secondCompute.getBody().front();
mlir::Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
secondCompute->getAttrOfType<mlir::DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum);
spatial::SpatYieldOp secondYield = mlir::cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
mlir::Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation
rewriter.setInsertionPoint(secondYield);
Value reduced = reduce(formerRes2, formerRes1);
mlir::Value reduced = reduce(formerRes2, formerRes1);
// Unfortunately, it is not possible to update the result in place,
// because we may have already referenced it by <computeOp, resultNum>
@@ -219,7 +220,7 @@ void SpatialReducer::finalizeReduceUpdates() {
// `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp)
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
toComputeOp = mlir::cast<spatial::SpatWeightedCompute>(toOp);
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
@@ -234,31 +235,31 @@ void SpatialReducer::finalizeReduceUpdates() {
}
}
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
mlir::Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
Operation* opToCast;
mlir::Operation* opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end())
opToCast = it->second;
else
opToCast = opAndResNum.first;
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
auto computeOp = mlir::cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second);
}
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again
return;
}
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp);
auto oldComputeOp = mlir::cast<spatial::SpatWeightedCompute>(computeOp);
auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
auto yieldOp = mlir::cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map
@@ -283,8 +284,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
// Since we replaced the old ComputeOp with a new one, we need to replace
// all its results' uses
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
Value oldResult = oldComputeOp.getResult(i);
Value newResult = newComputeOp.getResult(i);
mlir::Value oldResult = oldComputeOp.getResult(i);
mlir::Value newResult = newComputeOp.getResult(i);
// Replace the uses, except the uses of the compute ops which got deleted
// previously
@@ -298,9 +299,10 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
rewriter.eraseOp(oldComputeOp);
}
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
Location& loc,
Type outputType) {
mlir::Value
SpatialReducer::createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
mlir::Location& loc,
mlir::Type outputType) {
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
@@ -309,8 +311,8 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>> remappedOutputTiles(
tilesCount, llvm::SmallVector<llvm::SmallVector<mlir::Value>>(width, llvm::SmallVector<mlir::Value>(height)));
for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++)
@@ -320,16 +322,16 @@ Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAn
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
}
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter,
Value biasTile,
OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
mlir::ConversionPatternRewriter& rewriter,
mlir::Value biasTile,
MapOperations mapOp) {
std::function<Value(const Value&)> postprocessing = nullptr;
std::function<mlir::Value(const mlir::Value&)> postprocessing = nullptr;
if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) {
Value mapOperand = a;
postprocessing = [&](const mlir::Value a) {
mlir::Value mapOperand = a;
if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
return createMapOperation(rewriter, mapOp, mapOperand);
@@ -338,7 +340,7 @@ OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>&
return this->applyReducePattern(
computeOps,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
[&](mlir::Value a, mlir::Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
/* preprocess = */ nullptr,
postprocessing);
}

View File

@@ -3,6 +3,10 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include <functional>
#include <unordered_map>
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -13,28 +17,28 @@ using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange {
Operation* fromOp;
mlir::Operation* fromOp;
unsigned int fromOpResNum;
Operation* toOp;
mlir::Operation* toOp;
unsigned int toOpOperandNum;
};
using OpAndResNum = std::pair<Operation*, ResNum>;
using OpAndResNum = std::pair<mlir::Operation*, ResNum>;
class SpatialReducer {
public:
SpatialReducer(ConversionPatternRewriter& rewriter)
SpatialReducer(mlir::ConversionPatternRewriter& rewriter)
: rewriter(rewriter) {}
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess);
OpAndResNum applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
std::function<mlir::Value(const mlir::Value&)> preprocess,
std::function<mlir::Value(const mlir::Value&)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter,
Value biasTile,
OpAndResNum applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
mlir::ConversionPatternRewriter& rewriter,
mlir::Value biasTile,
MapOperations mapOp);
void finalizeReduceUpdates();
@@ -44,17 +48,17 @@ public:
finalizeReduceUpdates();
}
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Location& loc,
Type outputType);
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
mlir::Location& loc,
mlir::Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
mlir::Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
private:
[[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun,
ConversionPatternRewriter& rewriter);
std::function<mlir::Value(const mlir::Value&)> processFun,
mlir::ConversionPatternRewriter& rewriter);
/**
* @brief Update the results of a ComputeOp.
@@ -66,19 +70,19 @@ private:
*
* @param computeOp The ComputeOp to update the results of.
*/
void updateResultsOfCompute(Operation* computeOp);
void updateResultsOfCompute(mlir::Operation* computeOp);
ConversionPatternRewriter& rewriter;
mlir::ConversionPatternRewriter& rewriter;
bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized
SmallVector<SpatialReducerChange, 4> reducerChanges;
llvm::SmallVector<SpatialReducerChange, 4> reducerChanges;
// List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
llvm::SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
std::unordered_map<mlir::Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
static llvm::SmallPtrSet<mlir::Operation*, 16> oldComputeOpsReplaced;
};
} // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
namespace onnx_mlir {
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
WeightSubdivider::WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights)
: weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin();
SmallVector<Value>& values = it->second.begin()->second;
llvm::SmallVector<mlir::Value>& values = it->second.begin()->second;
long inputTile = it->first;
long outputTile = it->second.begin()->first;
@@ -21,7 +21,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
size_t n = std::min(amount, values.size());
crossbarsUsed += n;
SmallVector<Value> result;
llvm::SmallVector<mlir::Value> result;
result.assign(values.begin(), values.begin() + n);
if (n < values.size()) {
@@ -36,9 +36,9 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
return {inputTile, outputTile, crossbarsUsed - n, result};
}
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
llvm::SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
crossbarsUsed = 0;
SmallVector<TaggedWeights> result;
llvm::SmallVector<TaggedWeights> result;
size_t remaining = n;
while (remaining > 0 && !weights.empty()) {

View File

@@ -4,11 +4,9 @@
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <map>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
@@ -19,7 +17,7 @@ struct TaggedWeights {
long inputTile;
long outputTile;
size_t startingCrossbarIndex;
SmallVector<Value> weights;
llvm::SmallVector<mlir::Value> weights;
};
/**
@@ -33,16 +31,16 @@ struct TaggedWeights {
*/
class WeightSubdivider {
private:
map<long, map<long, SmallVector<Value>>> weights;
std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights;
size_t crossbarsUsed = 0;
TaggedWeights popGroup(size_t amount);
public:
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights);
WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights);
bool isEmpty() const;
SmallVector<TaggedWeights> popGroups(size_t n);
llvm::SmallVector<TaggedWeights> popGroups(size_t n);
};
} // namespace onnx_mlir

View File

@@ -10,6 +10,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -199,12 +200,12 @@ private:
void SpatialToGraphvizPass::runOnOperation() {
ModuleOp module = getOperation();
// Get the first OP, must be a FuncOp
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
if (!func) {
module->emitError("No FuncOp found in the begin of module");
auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp func = *entryFunc;
os << "digraph G {\n"
<< "\tnode [style=filled,color=white];\n";

View File

@@ -3,7 +3,6 @@ mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen)
add_onnx_mlir_library(OMSpatialToPIM
SpatialToPIMPass.hpp
SpatialToPIMPass.cpp
SpatialToPIMCommon.cpp

View File

@@ -3,10 +3,18 @@
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def onnxToPimTransposeOp : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMMOp : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,

View File

@@ -8,6 +8,7 @@
#include "SpatialToPIMCommon.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
@@ -90,8 +91,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera
auto resultShapedType = cast<ShapedType>(resultType);
rewriter.setInsertionPoint(operation);
return rewriter.create<tensor::EmptyOp>(
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
return tensor::EmptyOp::create(
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
} // namespace onnx_mlir

View File

@@ -4,8 +4,6 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
/**
@@ -20,10 +18,10 @@ namespace onnx_mlir {
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
template <class T>
size_t rangeLength(const iterator_range<T> range) {
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());
}
@@ -35,16 +33,16 @@ size_t rangeLength(const iterator_range<T> range) {
* @return The earliest user operation that uses the given value within the
* current block.
*/
Operation* getEarliestUserWithinBlock(Value value);
mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
const ArrayRef<int64_t> offsets,
const ArrayRef<int64_t> sizes,
const ArrayRef<int64_t> strides) {
static bool isMemoryContiguous(const mlir::ArrayRef<int64_t> srcShape,
const mlir::ArrayRef<int64_t> offsets,
const mlir::ArrayRef<int64_t> sizes,
const mlir::ArrayRef<int64_t> strides) {
// Check that all strides are 1
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
@@ -99,10 +97,13 @@ static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
return true;
}
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType());
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); }
inline bool isAConcatOp(mlir::Operation* op) {
return isa<mlir::tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op);
}
} // namespace onnx_mlir

View File

@@ -17,12 +17,64 @@
#include <string>
#include <utility>
#include "SpatialToPIMPass.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
using namespace spat_to_pim;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace
void SpatialToPIMPass::runOnOperation() {
coreId = 1;
@@ -35,14 +87,17 @@ void SpatialToPIMPass::runOnOperation() {
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) {
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (!funcOp)
llvm_unreachable("No FuncOp found in the begin of module");
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
@@ -260,7 +315,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
for (auto returnValue : returnOp->getOperands()) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!returnValueDefiningOp->hasAttr("weightAlways"));
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(returnValue);
}
else {
@@ -487,3 +542,7 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
}
}
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -1,60 +0,0 @@
#pragma once
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace spat_to_pim {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace spat_to_pim
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<spat_to_pim::SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -1,2 +1,2 @@
add_subdirectory(PIM)
add_subdirectory(Pim)
add_subdirectory(Spatial)

View File

@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerOpBufferizationInterfaces(DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "Dialect/PIM/PimOps.hpp"
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace pim {
#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace pim
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<pim::PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -175,7 +175,33 @@ def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
}];
}
// Computation
// Algebra
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
@@ -197,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {

View File

@@ -10,7 +10,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -20,7 +20,7 @@ namespace pim {
void PimDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
>();
}
@@ -45,5 +45,5 @@ POPULATE_DEPENDENCIES(PimVExpOp)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"

View File

@@ -12,7 +12,7 @@
#include <string>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"

View File

@@ -3,7 +3,6 @@ mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
add_onnx_mlir_library(OMPimBufferization
PimBufferizationPass.hpp
PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp

View File

@@ -1,4 +1,4 @@
#include "Dialect/PIM/Transforms/Bufferization/Common.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir;

View File

@@ -2,12 +2,10 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref);
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace bufferization;
@@ -76,6 +76,32 @@ struct MemCopyDevToHostOpInterface
}
};
struct TransposeOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op);
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
if (failed(dataOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
return success();
}
};
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -176,6 +202,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
namespace pim {
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -4,7 +4,7 @@
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def memrefCopyToPimMemCopyOp : Pat<

View File

@@ -7,12 +7,37 @@
#include "Common/PIMCommon.hpp"
#include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
@@ -68,15 +93,18 @@ void PimBufferizationPass::runOnOperation() {
}
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
assert("Weights must be constants" && globalMemrefOp.getConstant());
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp);
}
});
}
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -25,7 +25,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());

View File

@@ -18,7 +18,7 @@
#include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -4,14 +4,12 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -1,7 +1,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"

View File

@@ -0,0 +1,181 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Computes the transpose of a dense constant tensor at compile time.
///
/// \p perms must be a full permutation of [0, rank): out-of-range or duplicate
/// axes, or a non-ranked-tensor attribute type, yield failure(). On success the
/// returned attribute has the permuted shape and the same element type.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  // Only ranked tensor constants can be transposed here.
  auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!tensorType)
    return failure();
  int64_t rank = tensorType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();
  // Validate the permutation while building the transposed shape.
  llvm::SmallBitVector seen(rank);
  SmallVector<int64_t> transposedShape;
  transposedShape.reserve(rank);
  for (int64_t perm : perms) {
    // Reject out-of-range and duplicate axes.
    if (perm < 0 || perm >= rank || seen.test(perm))
      return failure();
    seen.set(perm);
    transposedShape.push_back(tensorType.getShape()[perm]);
  }
  auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
  // A splat is invariant under any axis permutation; just retype it.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
  SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> transposedValues(originalValues.size());
  // Row-major strides for the source and destination layouts.
  SmallVector<int64_t> originalStrides(rank, 1);
  SmallVector<int64_t> transposedStrides(rank, 1);
  for (int64_t dim = rank - 2; dim >= 0; --dim) {
    originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
    transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
  }
  SmallVector<int64_t> originalIndices(rank);
  SmallVector<int64_t> transposedIndices(rank);
  for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
    // Delinearize the source index into per-dimension coordinates.
    int64_t remaining = static_cast<int64_t>(linearIndex);
    for (int64_t dim = 0; dim < rank; ++dim) {
      originalIndices[dim] = remaining / originalStrides[dim];
      remaining %= originalStrides[dim];
    }
    // Destination dim `dim` reads from source dim `perms[dim]`.
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedIndices[dim] = originalIndices[perms[dim]];
    // Relinearize against the transposed strides and store the element.
    int64_t transposedLinearIndex = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
    transposedValues[transposedLinearIndex] = value;
  }
  return DenseElementsAttr::get(transposedType, transposedValues);
}
/// Folds a pim.transpose whose input is a constant memref.global into a fresh
/// constant memref.global holding the pre-transposed data, replacing the
/// transpose with a memref.get_global of the new symbol.
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
    // The folded global needs a fully static result type.
    auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    // Only fold when the data is produced by a get_global of a constant,
    // initialized global.
    auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
    if (!sourceGetGlobal)
      return failure();
    auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
    if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
      return failure();
    auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
    if (!denseAttr)
      return failure();
    SmallVector<int64_t> perms;
    perms.reserve(transposeOp.getPerms().size());
    for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
      perms.push_back(attr.getInt());
    FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
    if (failed(transposedAttr))
      return failure();
    // The compile-time transpose must agree with the declared result shape.
    auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
    if (!llvm::equal(transposedShape, resultType.getShape()))
      return failure();
    MemRefType globalType = resultType;
    // Pick a module-unique symbol name for the folded constant.
    auto globalName = sourceGlobal.getName().str() + "__folded_transpose";
    unsigned suffix = 0;
    while (moduleOp.lookupSymbol(globalName))
      globalName = sourceGlobal.getName().str() + "__folded_transpose_" + std::to_string(++suffix);
    auto visibility = rewriter.getStringAttr("private");
    // Create the new global through the pattern rewriter (not a detached
    // OpBuilder) so the greedy driver's listener observes the insertion; the
    // guard restores the insertion point for the get_global below.
    memref::GlobalOp newGlobal;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleOp.getBody());
      newGlobal = memref::GlobalOp::create(rewriter,
          transposeOp.getLoc(),
          globalName,
          visibility,
          globalType,
          *transposedAttr,
          /*constant=*/true,
          sourceGlobal.getAlignmentAttr());
    }
    rewriter.setInsertionPoint(transposeOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
    // Preserve the weight annotation when every user is a pim.core op, mirroring
    // the logic used by the bufferization pass for plain get_globals.
    bool isAlwaysWeight =
        !transposeOp->getUsers().empty()
        && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
    if (isAlwaysWeight) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }
    rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
    return success();
  }
};
/// Module pass that greedily applies every registered canonicalization pattern
/// plus FoldConstantTransposePattern (with folding enabled) so that host-side
/// constant expressions are materialized before PIM host verification runs.
struct PimFoldHostConstantsPass : PassWrapper<PimFoldHostConstantsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimFoldHostConstantsPass)

  StringRef getArgument() const override { return "fold-pim-host-constants-pass"; }
  StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }

  /// Collect the pattern set once per pass instance; freezing it here avoids
  /// rebuilding the patterns on every runOnOperation invocation.
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet collected(context);
    // Gather canonicalization patterns from every loaded dialect and every
    // registered operation, then add our constant-transpose folder on top.
    for (Dialect* loadedDialect : context->getLoadedDialects())
      loadedDialect->getCanonicalizationPatterns(collected);
    for (RegisteredOperationName opName : context->getRegisteredOperations())
      opName.getCanonicalizationPatterns(collected, context);
    collected.add<FoldConstantTransposePattern>(context);
    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(collected));
    return success();
  }

  void runOnOperation() override {
    // Enable op folding alongside pattern application for maximal constant
    // propagation.
    GreedyRewriteConfig config;
    config.enableFolding();
    if (failed(applyPatternsGreedily(getOperation(), *patterns, config)))
      signalPassFailure();
  }

  // Frozen pattern set shared across runs (and pass clones).
  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
/// Factory used by pass registration to instantiate the host-constant folding pass.
std::unique_ptr<Pass> createPimFoldHostConstantsPass() { return std::make_unique<PimFoldHostConstantsPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,173 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Returns true for ops that are pure address/metadata manipulation on the
/// host (allocations, global references, memref view reshapes, and spatial
/// channel creation) and therefore allowed to remain after PIM bufferization.
static bool isAddressOnlyHostOp(Operation* op) {
  if (isa<spatial::SpatChannelNewOp>(op))
    return true;
  return isa<memref::AllocOp, memref::GetGlobalOp, memref::SubViewOp, memref::CastOp,
             memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
}
/// Walks backwards through view-like memref ops (subview/cast/collapse/expand)
/// and reports whether the chain bottoms out in host-addressable storage: a
/// func.func block argument, a memref.alloc, or a memref.get_global.
static bool isHostAddressableValue(Value value) {
  Value current = value;
  for (;;) {
    // Function arguments are host-visible buffers by construction.
    if (auto blockArg = dyn_cast<BlockArgument>(current))
      return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
    Operation* producer = current.getDefiningOp();
    if (!producer)
      return false;
    // Direct host storage: allocation or reference to a global.
    if (isa<memref::AllocOp, memref::GetGlobalOp>(producer))
      return true;
    // View-like ops: keep tracing through their source operand.
    Value next;
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(producer))
      next = subviewOp.getSource();
    else if (auto castOp = dyn_cast<memref::CastOp>(producer))
      next = castOp.getSource();
    else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(producer))
      next = collapseOp.getSrc();
    else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(producer))
      next = expandOp.getSrc();
    else
      return false;
    current = next;
  }
}
/// Module pass that verifies no runtime host-side computation survives PIM
/// bufferization: only pim.core, func.return, and address-only memref/spatial
/// ops may remain in each function body, and weights/results must be backed by
/// constant globals or host-addressable storage respectively. All violations
/// are reported before the pass fails, so users see every problem at once.
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)

  StringRef getArgument() const override { return "verify-pim-host-pass"; }
  StringRef getDescription() const override {
    return "Verify that no runtime host-side code remains in bufferized PIM IR";
  }

  PimHostVerificationPass() {}
  PimHostVerificationPass(const PimHostVerificationPass& pass) {}

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    // Accumulate failures instead of bailing early so every illegal op is
    // diagnosed in a single compiler invocation.
    bool hasFailure = false;
    for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
      if (funcOp.isExternal())
        continue;
      // NOTE(review): only the entry block is inspected — assumes bufferized
      // PIM IR is flat (no control flow / nested regions) at this point;
      // confirm that upstream passes guarantee this.
      for (Operation& op : funcOp.getBody().front().getOperations()) {
        if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
          if (failed(verifyCoreWeights(moduleOp, coreOp)))
            hasFailure = true;
          continue;
        }
        if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
          if (failed(verifyReturnOp(returnOp)))
            hasFailure = true;
          continue;
        }
        // Anything that is not core/return must be an address-only host op.
        if (!isAddressOnlyHostOp(&op)) {
          op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
                         "fold it to constants or lower it into pim.core");
          hasFailure = true;
          continue;
        }
        if (failed(verifyAddressOnlyHostOp(&op)))
          hasFailure = true;
      }
    }
    if (hasFailure)
      signalPassFailure();
  }

private:
  /// Each pim.core weight must be a memref.get_global resolving to a constant,
  /// initialized memref.global so JSON codegen can embed the data.
  static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
    bool hasFailure = false;
    for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
      auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
      if (!getGlobalOp) {
        coreOp.emitOpError() << "weight #" << weightIndex
                             << " must be materialized as memref.get_global before JSON codegen";
        hasFailure = true;
        continue;
      }
      auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
      if (!globalOp) {
        coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
        hasFailure = true;
        continue;
      }
      if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
        coreOp.emitOpError() << "weight #" << weightIndex
                             << " must come from a constant memref.global with an initial value";
        hasFailure = true;
      }
    }
    return success(!hasFailure);
  }

  /// Every returned value must trace back to host-addressable storage.
  static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
    bool hasFailure = false;
    for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
      if (!isHostAddressableValue(operand)) {
        returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
        hasFailure = true;
      }
    }
    return success(!hasFailure);
  }

  /// For view-like ops, check their source chain; other address-only ops
  /// (alloc, get_global, channel creation) need no further checking.
  static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
      return verifyAddressOnlySource(op, subviewOp.getSource());
    if (auto castOp = dyn_cast<memref::CastOp>(op))
      return verifyAddressOnlySource(op, castOp.getSource());
    if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
      return verifyAddressOnlySource(op, collapseOp.getSrc());
    if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
      return verifyAddressOnlySource(op, expandOp.getSrc());
    return success();
  }

  /// Emits a diagnostic on \p op when \p source is not host-addressable.
  static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
    if (isHostAddressableValue(source))
      return success();
    op->emitOpError("depends on a value that still requires host-side execution");
    return failure();
  }
};
} // namespace
/// Factory used by pass registration to instantiate the host verification pass.
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
} // namespace onnx_mlir

View File

@@ -3,23 +3,26 @@
#include "mlir/Pass/Pass.h"
#include <memory>
using namespace mlir;
#include <string>
namespace onnx_mlir {
std::unique_ptr<Pass> createONNXToSpatialPass();
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<Pass> createSpatialToGraphvizPass();
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<Pass> createSpatialToPIMPass();
std::unique_ptr<mlir::Pass> createSpatialToPIMPass();
std::unique_ptr<Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<Pass> createEmitPimJsonPass();
std::unique_ptr<mlir::Pass> createPimFoldHostConstantsPass();
std::unique_ptr<Pass> createMessagePass(std::string message);
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
std::unique_ptr<Pass> createCountInstructionPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
std::unique_ptr<mlir::Pass> createCountInstructionPass();
} // namespace onnx_mlir

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -12,8 +13,8 @@
#include "llvm/Support/Debug.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -40,28 +41,27 @@ PimAccelerator::PimAccelerator()
acceleratorTargets.push_back(this);
}
PimAccelerator::~PimAccelerator() { delete instance; }
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
void PimAccelerator::addPasses(OwningOpRef<ModuleOp>& module,
PassManager& pm,
void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
mlir::PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const {
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
addPassesPim(module, pm, emissionTarget, outputNameNoExt);
}
void PimAccelerator::registerDialects(DialectRegistry& registry) const {
void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
registry.insert<bufferization::BufferizationDialect>();
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<mlir::tosa::TosaDialect>();
registry.insert<mlir::bufferization::BufferizationDialect>();
registry.insert<pim::PimDialect>();
registry.insert<spatial::SpatialDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerOpBufferizationInterfaces(registry);
@@ -73,6 +73,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPIMPass);
registerPass(createBufferizePimPass);
registerPass(createPimFoldHostConstantsPass);
registerPass(createPimHostVerificationPass);
registerPass(createEmitPimJsonPass);
}
@@ -81,26 +83,26 @@ void PimAccelerator::configurePasses() const {
// TODO: This does nothing for now.
}
MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const {
mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const {
// Do not convert tensor types to memref types.
return nullptr;
}
void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const {
void PimAccelerator::conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const {
target.addLegalDialect<pim::PimDialect>();
}
void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns,
TypeConverter& typeConverter,
MLIRContext* ctx) const {
void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
mlir::TypeConverter& typeConverter,
mlir::MLIRContext* ctx) const {
// TODO: Add patterns for conversion
}
void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {}
void PimAccelerator::conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const {}
void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns,
LLVMTypeConverter& typeConverter,
MLIRContext* ctx) const {
void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
mlir::LLVMTypeConverter& typeConverter,
mlir::MLIRContext* ctx) const {
// We should not need this, since we offload it all to PIM.
}

View File

@@ -18,8 +18,6 @@ public:
PimAccelerator(PimAccelerator&) = delete;
void operator=(const PimAccelerator&) = delete;
~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance.
static PimAccelerator* getInstance();

View File

@@ -102,46 +102,13 @@ def gen_c(inputs, outputs, entry, so_name):
if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}}
"""))
# Output printing + optional per-output CSV dump
out_blocks=[]
# Optional per-output CSV dump
csv_write_blocks=[]
for oi,name,et,shape in outputs:
if et not in DTYPES:
raise ValueError(f"Unsupported dtype for output '{name}': {et}")
cty, pfmt, _ = DTYPES[et]
safe = esc(name)
out_blocks.append(textwrap.dedent(f"""
// ---- Output {oi}: "{safe}" ----
{{
OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
int64_t rank = omTensorGetRank(t);
int64_t const *shape = omTensorGetShape(t);
long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
{cty} *p = ({cty}*)omTensorGetDataPtr(t);
printf("Output {oi} ('{safe}'): shape=[");
for (int64_t k=0;k<rank;k++) printf("%ld%s",(long)shape[k], (k+1<rank)?",":"");
printf("]\\n");
if (rank == 2) {{
int64_t R = shape[0], C = shape[1];
for (int64_t r=0; r<R; ++r) {{
for (int64_t c=0; c<C; ++c) {{
long long idx = r*C + c;
printf("{pfmt}%s", p[idx], (c+1<C)?", ":"");
}}
printf("\\n");
}}
}} else {{
// Flattened vector with indices
for (long long i=0;i<numel;i++) {{
printf("[%lld]={pfmt}%s", i, p[i], (i+1<numel)?", ":"\\n");
}}
}}
}}
"""))
# Per-output CSV writer into --save-csv-dir
csv_write_blocks.append(textwrap.dedent(f"""
if (save_csv_dir) {{
// Build "DIR/output{oi}_<sanitized name>.csv"
@@ -227,9 +194,6 @@ int main(int argc, char **argv) {{
OMTensorList *out_list = {entry}(in_list);
if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}}
// ---- Print full outputs ----
{"".join(out_blocks)}
// ---- Optional per-output CSV dump ----
{"".join(csv_write_blocks)}
@@ -240,7 +204,7 @@ int main(int argc, char **argv) {{
}}
"""
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None):
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None, verbose=True):
ins, outs = onnx_io(network_onnx)
out_c = out or "runner.c"
so_abs = os.path.abspath(network_so)
@@ -260,6 +224,7 @@ set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)})
target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so)
"""
pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake)
if verbose:
print(f"[OK] Wrote {out_c}")
print("[OK] Wrote CMakeLists.txt")

View File

@@ -1,9 +1,10 @@
import subprocess
from pathlib import Path
from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter
def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count):
def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count, reporter=None):
# Define the arguments, with the possibility to set crossbar size and count
args = [
network_path,
@@ -14,16 +15,14 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, cro
f"--crossbar-count={crossbar_count}",
]
# Run the executable with the arguments
try:
result = subprocess.run(
run_command_with_reporter(
[str(raptor_onnx_path)] + [str(arg) for arg in args],
check=True,
capture_output=True,
text=True,
reporter=reporter,
)
print(result.stdout + Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
except subprocess.CalledProcessError as e:
print(Fore.RED + "Error executing ONNX-MLIR:")
print(e.stderr + Style.RESET_ALL)
if reporter is None:
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
except subprocess.CalledProcessError:
if reporter is None:
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
raise

View File

@@ -0,0 +1,70 @@
import errno
import os
import pty
import selectors
import subprocess
def _read_chunk(fd, treat_eio_as_eof=False):
try:
return os.read(fd, 4096)
except OSError as exc:
if treat_eio_as_eof and exc.errno == errno.EIO:
return b""
raise
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
    """Mirror a child process's output to stdout without corrupting the
    reporter's progress line, then wait for the child and raise on failure.

    fd: readable file descriptor connected to the child's output (pipe or
        pty master). It is closed here once EOF is reached.
    process: the subprocess.Popen handle to wait on after the stream drains.
    reporter: progress reporter providing _clear()/_render() to erase and
        redraw its status line around each chunk of child output.
    treat_eio_as_eof: forwarded to _read_chunk; set for pty masters, where
        EIO signals the slave side closed (see _read_chunk).

    Raises subprocess.CalledProcessError when the child exits non-zero.
    """
    selector = selectors.DefaultSelector()
    try:
        selector.register(fd, selectors.EVENT_READ)
        # Loop until the fd is unregistered at EOF (get_map() becomes empty).
        while selector.get_map():
            for key, _ in selector.select():
                data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
                if not data:
                    # EOF: stop watching and close our end of the pipe/pty.
                    selector.unregister(key.fileobj)
                    os.close(key.fileobj)
                    continue
                # Erase the progress bar, emit the child's bytes verbatim to
                # fd 1 (bypassing sys.stdout buffering), then redraw the bar.
                reporter._clear()
                os.write(1, data)
                reporter._render()
    finally:
        selector.close()
    return_code = process.wait()
    if return_code != 0:
        raise subprocess.CalledProcessError(return_code, process.args)
def run_command_with_reporter(cmd, cwd=None, reporter=None):
    """Run *cmd*, optionally streaming its output through a progress reporter.

    When *reporter* is None the command runs directly via ``subprocess.run``
    and output goes straight to the terminal. Otherwise the child is attached
    to a pseudo-terminal (so tools keep their tty formatting) and its output
    is interleaved with the reporter's progress bar; when no pty can be
    allocated, a plain pipe with stderr merged into stdout is used instead.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    if reporter is None:
        subprocess.run(cmd, cwd=cwd, check=True)
        return
    try:
        master_fd, slave_fd = pty.openpty()
    except OSError:
        # No pty available (e.g. restricted environment): fall back to a pipe,
        # merging stderr into stdout to preserve output ordering.
        process = subprocess.Popen(
            cmd,
            cwd=cwd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        assert process.stdout is not None
        _stream_output(process.stdout.fileno(), process, reporter)
        return
    try:
        try:
            process = subprocess.Popen(
                cmd,
                cwd=cwd,
                stdout=slave_fd,
                stderr=slave_fd,
            )
        finally:
            # The parent never uses the slave end; the child keeps its own
            # copy, so close ours unconditionally.
            os.close(slave_fd)
    except Exception:
        # Popen failed: also release the master end, which _stream_output
        # would otherwise have closed at EOF. Without this the fd leaks.
        os.close(master_fd)
        raise
    _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)

View File

@@ -1,10 +1,11 @@
#!/usr/bin/env python3
import argparse
import subprocess
import sys
from pathlib import Path
from colorama import Style, Fore
from validate_one import validate_network
from validate_one import ProgressReporter, validate_network
def main():
@@ -34,32 +35,48 @@ def main():
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
sys.exit(1)
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL)
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
print(f"Operations root: {operations_dir}")
print("=" * 72)
results = {} # relative_path -> passed
for onnx_path in onnx_files:
reporter = ProgressReporter(len(onnx_files))
for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir)
header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}"
print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL)
try:
passed = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
reporter=reporter,
model_index=index,
model_total=len(onnx_files),
)
results[str(rel)] = passed
except (subprocess.CalledProcessError, Exception):
results[str(rel)] = False
reporter.finish()
# Summary
n_passed = sum(results.values())
n_passed = sum(1 for passed in results.values() if passed)
n_total = len(results)
print("\n" + Style.BRIGHT + "=" * 60)
print(" Summary")
print("=" * 60 + Style.RESET_ALL)
status_width = len("Result")
path_width = max(len("Operation"), *(len(rel) for rel in results))
separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
print(separator)
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
print(separator)
for rel, passed in results.items():
status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL"
print(f" {rel}: {status}" + Style.RESET_ALL)
print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." + Style.RESET_ALL)
plain_status = "PASS" if passed else "FAIL"
status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
print(f"| {rel.ljust(path_width)} | {status} |")
print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
sys.exit(0 if n_passed == n_total else 1)

View File

import argparse
import json
import numpy as np
import subprocess
import shutil
import sys
from pathlib import Path
from colorama import Style, Fore
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
from raptor import compile_with_raptor
from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
# Number of per-model pipeline stages reported by validate_network()
# (Compile ONNX, Build Runner, Generate Inputs, Run Reference,
# Compile PIM, Run Simulator, Compare Outputs). Must match the number
# of ProgressReporter.advance() calls per model, otherwise the progress
# bar reaches 100% before the last stage completes.
STAGE_COUNT = 7
class ProgressReporter:
    """Single-line ANSI progress bar for multi-model validation runs.

    Progress is tracked as (models x stages) discrete steps and redrawn as
    one status line on stdout. `log()` prints a normal line above the bar;
    `suspend()`/`resume()` temporarily hide it so multi-line output (e.g.
    result tables) is not clobbered by redraws.
    """
    def __init__(self, total_models, stages_per_model=STAGE_COUNT):
        self.total_models = total_models
        self.stages_per_model = stages_per_model
        # Guard against zero models so _render never divides by zero.
        self.total_steps = max(1, total_models * stages_per_model)
        self.completed_steps = 0
        self.current_label = ""
        self.enabled = True
        # Terminal width sampled once at construction; (100, 20) is the
        # fallback size when stdout is not a terminal.
        self.columns = shutil.get_terminal_size((100, 20)).columns
        self.suspended = False
    def _clear(self):
        # \033[2K erases the whole line; \r returns the cursor to column 0.
        if self.enabled:
            sys.stdout.write("\033[2K\r")
    def _render(self):
        # Redraw the progress line in place (no trailing newline).
        if not self.enabled or self.suspended:
            return
        bar_width = 24
        filled = int(bar_width * self.completed_steps / self.total_steps)
        prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
        # Fall back to a bare "done/total" counter on very narrow terminals.
        if len(prefix_text) > self.columns:
            prefix_text = f"{self.completed_steps}/{self.total_steps}"
        label = f" {self.current_label}" if self.current_label else ""
        # Width math uses the uncolored text; ANSI codes are added afterwards
        # so they do not count against the terminal width.
        available_label_width = max(0, self.columns - len(prefix_text))
        label = label[:available_label_width]
        if prefix_text.startswith("["):
            # Colorize the bar: green for completed steps, cyan for the rest.
            bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
            prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
        else:
            prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
        sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
        sys.stdout.flush()
    def log(self, message="", color=None):
        """Print a normal line above the bar, then redraw the bar."""
        if self.enabled:
            self._clear()
        if color:
            print(color + message + Style.RESET_ALL)
        else:
            print(message)
        self._render()
    def set_stage(self, model_index, model_total, model_name, stage_name):
        """Point the bar's label at the current model/stage and redraw."""
        self.current_label = f"[{model_index}/{model_total}] {model_name} · {stage_name}"
        self._render()
    def advance(self):
        """Mark one stage complete (clamped at total_steps) and redraw."""
        self.completed_steps = min(self.total_steps, self.completed_steps + 1)
        self._render()
    def suspend(self):
        """Hide the bar so other code can print multi-line output."""
        self.suspended = True
        self._clear()
    def resume(self):
        """Show the bar again after a suspend()."""
        self.suspended = False
        self._render()
    def finish(self):
        """Permanently clear the bar at the end of the run."""
        if self.enabled:
            self.suspended = True
            self._clear()
            sys.stdout.flush()
def run_command(cmd, cwd=None, reporter=None):
    """Execute *cmd* (optionally in *cwd*), routing output via *reporter*.

    Thin delegation to the shared subprocess helper so command output and
    progress-bar redraws do not interleave on the terminal.
    """
    run_command_with_reporter(cmd, reporter=reporter, cwd=cwd)
def print_stage(reporter, model_index, model_total, model_name, title):
    """Emit a bright, color-coded stage banner and update the progress label."""
    # One banner color per pipeline stage; unknown titles render white.
    stage_palette = {
        "Compile ONNX": Fore.BLUE,
        "Build Runner": Fore.MAGENTA,
        "Generate Inputs": Fore.YELLOW,
        "Run Reference": Fore.GREEN,
        "Compile PIM": Fore.CYAN,
        "Run Simulator": Fore.MAGENTA,
        "Compare Outputs": Fore.YELLOW,
    }
    banner_color = stage_palette.get(title, Fore.WHITE)
    banner = Style.BRIGHT + banner_color + f"[{title}]" + Style.RESET_ALL
    reporter.log(banner)
    reporter.set_stage(model_index, model_total, model_name, title)
def print_info(reporter, message):
    """Log an indented informational detail line through the reporter."""
    reporter.log(" {}".format(message))
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
run_command([raptor_path, network_onnx_path, "--EmitONNXIR"], reporter=reporter)
run_command([raptor_path, network_onnx_path], reporter=reporter)
parent = network_onnx_path.parent
stem = network_onnx_path.stem
so_path = parent / f"{stem}.so"
@@ -25,9 +123,9 @@ def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir)
return moved_so, moved_mlir
def build_onnx_runner(source_dir, build_dir, reporter=None):
    """Configure and build the ONNX runner with CMake.

    Args:
        source_dir: directory containing the runner's CMakeLists.txt.
        build_dir: out-of-source build directory (must already exist).
        reporter: optional ProgressReporter; build output is routed
            through it so the progress bar is not corrupted.

    Returns:
        Path to the built `runner` executable inside build_dir.
    """
    run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter)
    # "-j": build in parallel with the toolchain's default job count.
    run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter)
    return build_dir / "runner"
@@ -41,11 +139,12 @@ def build_dump_ranges(config_path, outputs_descriptor):
return ",".join(ranges)
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
    """Run the Rust PIM simulator over the compiled artifacts in *pim_dir*.

    Args:
        simulator_dir: cargo workspace containing the pim-simulator package.
        pim_dir: directory with the compiled PIM program and memory image (-f).
        output_bin_path: file the simulator dumps memory contents into (-o).
        dump_ranges: comma-separated address ranges to dump (-d).
        reporter: optional ProgressReporter for interleaving cargo output
            with the progress bar.
    """
    run_command(
        ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
         "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
        cwd=simulator_dir,
        reporter=reporter,
    )
@@ -64,24 +163,41 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor):
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3):
    """Compare simulator outputs against the reference runner's CSV outputs.

    Prints a small PASS/FAIL table (one row per network output) and returns
    True only when every output's max absolute difference is <= threshold.

    Args:
        sim_arrays: arrays parsed from the simulator's output dump, in the
            same order as outputs_descriptor.
        runner_out_dir: directory holding the reference runner's
            "output{i}_{name}.csv" files.
        outputs_descriptor: (index, name, dtype, shape) tuples per output.
        threshold: max tolerated absolute elementwise difference.
    """
    all_passed = True
    rows = []
    for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
        csv_name = f"output{oi}_{name}.csv"
        runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
        # Compare in float64 so the reduction itself adds no extra error.
        max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
        passed = max_diff <= threshold
        rows.append((name, f"{max_diff:.6e}", passed))
        if not passed:
            all_passed = False
    # List-based max so an empty output list cannot raise TypeError.
    name_width = max([len("Output")] + [len(name) for name, _, _ in rows])
    diff_width = max([len("Max diff")] + [len(diff) for _, diff, _ in rows])
    result_width = len("Result")
    separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+"
    print(separator)
    print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
    print(separator)
    for name, diff_text, passed in rows:
        status_text = ("PASS" if passed else "FAIL").ljust(result_width)
        status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL
        print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |")
    print(separator)
    return all_passed
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3):
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3,
reporter=None, model_index=1, model_total=1):
network_onnx_path = Path(network_onnx_path).resolve()
raptor_path = Path(raptor_path).resolve()
onnx_include_dir = Path(onnx_include_dir).resolve()
simulator_dir = Path(simulator_dir).resolve()
owns_reporter = reporter is None
reporter = reporter or ProgressReporter(model_total)
workspace_dir = network_onnx_path.parent
raptor_dir = workspace_dir / "raptor"
@@ -90,40 +206,72 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
Path.mkdir(raptor_dir, exist_ok=True)
Path.mkdir(runner_build_dir, parents=True, exist_ok=True)
print(Style.BRIGHT + "\nCompiling the onnx network:" + Style.RESET_ALL)
network_so_path, network_mlir_path = compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir)
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
print(Style.BRIGHT + "\nGenerating and building the runner:" + Style.RESET_ALL)
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c")
runner_path = build_onnx_runner(runner_dir, runner_build_dir)
try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
network_so_path, network_mlir_path = compile_onnx_network(
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter)
print_info(reporter, f"MLIR saved to {network_mlir_path}")
print_info(reporter, f"Shared library saved to {network_so_path}")
reporter.advance()
print(Style.BRIGHT + "\nGenerating random inputs:" + Style.RESET_ALL)
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
print_info(reporter, f"Runner built at {runner_path}")
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Generate Inputs")
inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path)
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor)
flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs")
print_info(reporter, f"Saved {len(inputs_list)} input file(s) to {workspace_dir / 'inputs'}")
reporter.advance()
print(Style.BRIGHT + "\nRunning inference with the runner:" + Style.RESET_ALL)
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Reference")
out_dir = workspace_dir / "outputs"
Path.mkdir(out_dir, exist_ok=True)
run_cmd = [runner_path, *flags]
run_cmd += ["--save-csv-dir", f"{out_dir}"]
subprocess.run(run_cmd, cwd=runner_build_dir, check=True)
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter)
print_info(reporter, f"Reference outputs saved to {out_dir}")
reporter.advance()
print(Style.BRIGHT + "\nCompiling for PIM with Raptor:" + Style.RESET_ALL)
compile_with_raptor(network_mlir_path, raptor_path, crossbar_size, crossbar_count)
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
compile_with_raptor(
network_mlir_path, raptor_path, crossbar_size, crossbar_count, reporter=reporter)
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
reporter.advance()
print(Style.BRIGHT + "\nRunning PIM simulation:" + Style.RESET_ALL)
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Simulator")
pim_dir = raptor_dir / "pim"
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list)
simulation_dir = workspace_dir / "simulation"
Path.mkdir(simulation_dir, exist_ok=True)
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
output_bin_path = simulation_dir / "out.bin"
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges)
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter)
print_info(reporter, f"Simulator output saved to {output_bin_path}")
reporter.advance()
print(Style.BRIGHT + "\nValidating the results:" + Style.RESET_ALL)
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs")
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
return validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
reporter.suspend()
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
reporter.resume()
reporter.advance()
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
except Exception:
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
raise
finally:
reporter.log("=" * 72)
if owns_reporter:
reporter.finish()
if __name__ == '__main__':