// c77ffa9c56: support for tensors of index values (380 lines, 16 KiB, C++)
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
#include <cassert>
|
|
#include <utility>
|
|
|
|
#include "Common/PimCommon.hpp"
|
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
|
#include "Conversion/SpatialToPim/Common.hpp"
|
|
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
|
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
|
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
|
#include "Dialect/Pim/PimOps.hpp"
|
|
#include "Dialect/Spatial/SpatialOps.hpp"
|
|
#include "Pass/PIMPasses.h"
|
|
#include "SpatialToPimPass.hpp"
|
|
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
using namespace pim;
|
|
|
|
namespace onnx_mlir {
|
|
namespace raptor {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
|
|
} // namespace raptor
|
|
|
|
/// Returns a constant `memref.global` holding an all-zero tensor with the shape
/// and element type of `tensorType`, creating a private one at the start of the
/// enclosing module when no matching global exists yet.
///
/// An existing global is reused when it is constant, has the same memref type
/// and its initial value is the identical zero `DenseElementsAttr`.
///
/// The zero initializer is deliberately built on an encoding-free tensor type:
/// `memref.global` verifies that its initial value has exactly the tensor type
/// derived from its memref type (which carries no tensor encoding). Building it
/// from an encoded `tensorType` produced an invalid global and also prevented
/// any existing zero global from ever being matched.
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
  auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  assert(moduleOp && "zero globals can only be created below a module");

  Type elementType = tensorType.getElementType();
  auto memRefType = MemRefType::get(tensorType.getShape(), elementType);
  auto initializerType = RankedTensorType::get(tensorType.getShape(), elementType);
  Attribute zeroElement = rewriter.getZeroAttr(elementType);
  assert(zeroElement && "element type has no zero attribute");
  auto zeroAttr = DenseElementsAttr::get(initializerType, zeroElement);

  // Reuse an equivalent constant zero global if one is already present.
  for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
    if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
      continue;
    if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
      return globalOp;
  }

  // Pick a symbol name that does not clash with anything already in the module:
  // the stem first, then stem_0, stem_1, ...
  std::string nameStem = (Twine("__pim_zero_") + Twine(tensorType.getRank()) + "d_"
                          + Twine(tensorType.getNumElements()))
                             .str();
  std::string symbolName = nameStem;
  for (unsigned suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, symbolName); ++suffix)
    symbolName = (Twine(nameStem) + "_" + Twine(suffix)).str();

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return memref::GlobalOp::create(rewriter,
                                  loc,
                                  rewriter.getStringAttr(symbolName),
                                  rewriter.getStringAttr("private"),
                                  TypeAttr::get(memRefType),
                                  zeroAttr,
                                  rewriter.getUnitAttr(),
                                  IntegerAttr {});
}
|
|
|
|
/// Creates a device tensor of type `tensorType` and fills it with zeros by
/// copying from a shared zero-initialized host global.
///
/// Inside a `PimCoreBatchOp` the batched host-to-device copy is emitted (its
/// offsets are attributes); everywhere else the plain copy, which takes its
/// offsets as index SSA values, is emitted.
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
                                       Location loc,
                                       RankedTensorType tensorType,
                                       OperationFolder& constantFolder) {
  auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
  auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
  auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
  auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));

  if (outputBuffer->getParentOfType<PimCoreBatchOp>()) {
    auto zeroOffsetAttr = rewriter.getI32IntegerAttr(0);
    return PimMemCopyHostToDevBatchOp::create(
               rewriter, loc, tensorType, outputBuffer, zeroValue, zeroOffsetAttr, zeroOffsetAttr, sizeAttr)
        .getOutput();
  }

  // Only the non-batched copy consumes SSA offsets, so the index constant is
  // materialized here rather than up front (where the batch path left it unused).
  auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
  return PimMemCopyHostToDevOp::create(
             rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
      .getOutput();
}
|
|
|
|
/// Widens the horizontal vector `vector` to exactly `crossbarSize` columns.
///
/// A vector that already spans the full crossbar is returned unchanged.
/// Otherwise a zero-filled device vector of the padded type (same row count,
/// element type and encoding) is created, and the original vector's bytes are
/// copied into it at offset 0 with a `PimMemCopyOp`.
static Value
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
  auto sourceType = cast<RankedTensorType>(vector.getType());
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  const auto crossbarWidth = static_cast<int64_t>(crossbarSize);
  assert(isHVectorShape(sourceShape) && "expected a horizontal vector");
  assert(sourceShape[1] <= crossbarWidth && "vector width must fit in one crossbar");

  const bool alreadyFullWidth = sourceShape[1] == crossbarWidth;
  if (alreadyFullWidth)
    return vector;

  auto widenedType =
      RankedTensorType::get({sourceShape[0], crossbarWidth}, sourceType.getElementType(), sourceType.getEncoding());
  Value zeroFilled = createZeroedDeviceHVector(rewriter, loc, widenedType, constantFolder);

  auto noOffset = rewriter.getI32IntegerAttr(0);
  auto copyBytes = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(sourceType)));
  auto copyOp = PimMemCopyOp::create(rewriter, loc, widenedType, zeroFilled, vector, noOffset, noOffset, copyBytes);
  return copyOp.getOutput();
}
|
|
|
|
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|
coreId = 0;
|
|
outputTensors.clear();
|
|
operationsToRemove.clear();
|
|
ModuleOp moduleOp = getOperation();
|
|
MLIRContext* ctx = moduleOp.getContext();
|
|
|
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
|
if (failed(entryFunc)) {
|
|
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
func::FuncOp funcOp = *entryFunc;
|
|
|
|
IRRewriter rewriter(&getContext());
|
|
OperationFolder constantFolder(&getContext());
|
|
|
|
ConversionTarget target(*ctx);
|
|
target.addLegalDialect<PimDialect,
|
|
tensor::TensorDialect,
|
|
arith::ArithDialect,
|
|
bufferization::BufferizationDialect,
|
|
func::FuncDialect,
|
|
memref::MemRefDialect,
|
|
scf::SCFDialect,
|
|
BuiltinDialect>();
|
|
target.addLegalOp<spatial::SpatConcatOp,
|
|
spatial::SpatChannelReceiveOp,
|
|
spatial::SpatChannelReceiveTensorOp,
|
|
spatial::SpatChannelReceiveTensorBatchOp,
|
|
spatial::SpatChannelSendOp,
|
|
spatial::SpatChannelSendTensorOp,
|
|
spatial::SpatChannelSendTensorBatchOp,
|
|
spatial::SpatExtractRowsOp>();
|
|
|
|
RewritePatternSet initialPatterns(ctx);
|
|
populateWithGenerated(initialPatterns);
|
|
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
|
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
RewritePatternSet globalTensorPatterns(ctx);
|
|
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
|
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
|
|
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
|
addReturnOutputBuffers(returnOp, rewriter);
|
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
|
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
|
markOpToRemove(computeOp);
|
|
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
|
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
|
markOpToRemove(computeBatchOp);
|
|
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
|
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
RewritePatternSet initialTensorPackingPatterns(ctx);
|
|
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
|
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
|
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
|
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
|
receiveOps.push_back(op);
|
|
for (auto receiveOp : receiveOps) {
|
|
bool onlyPendingRemovalUsers = llvm::all_of(
|
|
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
|
|
if (onlyPendingRemovalUsers) {
|
|
markOpToRemove(receiveOp);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
RewritePatternSet coreBodyPatterns(ctx);
|
|
populateWithGenerated(coreBodyPatterns);
|
|
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
|
|
|
SmallVector<pim::PimCoreOp> coreOps;
|
|
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
|
for (auto coreOp : coreOps) {
|
|
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
|
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
|
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
|
for (auto coreBatchOp : coreBatchOps) {
|
|
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
|
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
|
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
|
eraseOpsToRemove();
|
|
|
|
RewritePatternSet finalTensorPackingPatterns(ctx);
|
|
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
|
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
|
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
|
|
|
ConversionTarget communicationTarget(*ctx);
|
|
communicationTarget.addLegalDialect<PimDialect,
|
|
tensor::TensorDialect,
|
|
arith::ArithDialect,
|
|
bufferization::BufferizationDialect,
|
|
func::FuncDialect,
|
|
memref::MemRefDialect,
|
|
scf::SCFDialect,
|
|
BuiltinDialect>();
|
|
communicationTarget.addLegalOp<ModuleOp>();
|
|
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
|
spatial::SpatChannelReceiveOp,
|
|
spatial::SpatChannelReceiveTensorOp,
|
|
spatial::SpatChannelSendOp,
|
|
spatial::SpatChannelSendTensorOp,
|
|
spatial::SpatExtractRowsOp>();
|
|
|
|
RewritePatternSet communicationPatterns(ctx);
|
|
populateChannelLoweringPatterns(communicationPatterns);
|
|
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
|
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
|
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// Dump to file for debug
|
|
dumpModule(moduleOp, "pim0");
|
|
}
|
|
|
|
/// Widens every `PimVMMOp` in `funcOp` to the full crossbar width.
///
/// For each VMM:
/// - its input is padded to `crossbarSize` columns;
/// - its output buffer is replaced by a full-width empty tensor (unless it is
///   already full width);
/// - its result is retyped to the full-width type.
///
/// When the original output was narrower, a `tensor.extract_slice` recovering
/// the original columns is inserted right after the VMM, and every other user
/// of the result is redirected to that slice.
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  OperationFolder constantFolder(funcOp.getContext());
  const auto crossbarWidth = static_cast<int64_t>(crossbarSize);

  funcOp.walk([&](PimVMMOp vmmOp) {
    Location loc = vmmOp.getLoc();
    Value result = vmmOp.getOutput();
    auto narrowType = cast<RankedTensorType>(result.getType());
    ArrayRef<int64_t> narrowShape = narrowType.getShape();
    assert(isHVectorShape(narrowShape) && "expected a horizontal vector output");
    assert(narrowShape[1] <= crossbarWidth && "output width must fit in one crossbar");
    const bool alreadyFullWidth = narrowShape[1] == crossbarWidth;

    auto wideType =
        RankedTensorType::get({narrowShape[0], crossbarWidth}, narrowType.getElementType(), narrowType.getEncoding());

    // The padded input and any new output buffer are both created before the VMM.
    rewriter.setInsertionPoint(vmmOp);
    Value wideInput = padHVectorInputToCrossbarSize(rewriter, loc, vmmOp.getInput(), constantFolder);
    Value wideOutputBuffer = vmmOp.getOutputBuffer();
    if (!alreadyFullWidth)
      wideOutputBuffer = createEmptyTensorFromShaped(rewriter, loc, wideType).getResult();

    vmmOp.getInputMutable().assign(wideInput);
    vmmOp.getOutputBufferMutable().assign(wideOutputBuffer);
    result.setType(wideType);

    if (alreadyFullWidth)
      return;

    // Slice the original columns back out so downstream users keep the narrow type.
    auto zero = rewriter.getIndexAttr(0);
    auto one = rewriter.getIndexAttr(1);
    SmallVector<OpFoldResult> offsets = {zero, zero};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(narrowShape[0]), rewriter.getIndexAttr(narrowShape[1])};
    SmallVector<OpFoldResult> strides = {one, one};

    rewriter.setInsertionPointAfter(vmmOp);
    auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, narrowType, result, offsets, sizes, strides);
    SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
    result.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
  });
}
|
|
|
|
/// Copies host globals read by self-contained `spat.compute` ops to device tensors.
///
/// A top-level `spat.compute` in `funcOp` qualifies when it has no inputs and
/// its body block has no arguments. For each qualifying compute:
/// - every `memref.get_global` whose symbol name starts with "arg" or "const_"
///   is followed to its single user;
/// - that user's first result is copied into a fresh `tensor.empty` with a
///   `PimMemCopyHostToDevOp`;
/// - all other uses of that result are redirected to the copy.
///
/// Values whose element type is not byte-sized are left untouched.
///
/// Returns failure, with a diagnostic on the offending op, when a matching
/// global has more than one user or its user produces no result. Both cases
/// used to be guarded only by an `assert`. In release builds they either
/// dereferenced an empty user/result range or silently rewired only the first
/// user. Unused matching globals need no copy and are skipped.
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
                                                                                IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();
  OperationFolder constantFolder(funcOp.getContext());

  // Inserts a host-to-device copy of `inputTensor` right after its defining op
  // and makes every other user read the copied device tensor instead.
  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    if (!hasByteSizedElementType(elementType))
      return;
    size_t elementByteSize = getElementTypeSizeInBytes(elementType);
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());

    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);

    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
        getOrCreateHostIndexConstant(
            deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
        deviceTensor,
        inputTensor,
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };

  for (auto& op : funcOp.getBody().getOps()) {
    auto computeOp = dyn_cast<spatial::SpatCompute>(op);
    if (!computeOp)
      continue;
    // Only computes without operands and without block arguments are handled here.
    if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
      continue;

    for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
      StringRef globalName = getGlobal.getName();
      if (!globalName.starts_with("arg") && !globalName.starts_with("const_"))
        continue;

      // Nothing reads this global inside the compute, so there is nothing to copy.
      if (getGlobal->use_empty())
        continue;
      if (!getGlobal->hasOneUse())
        return getGlobal.emitOpError("global must have a single entry point in the compute");

      Operation* entryUser = *getGlobal->getUsers().begin();
      if (entryUser->getNumResults() == 0)
        return entryUser->emitOpError(
            "expected the single user of a host global to produce the tensor to copy to the device");

      insertMemCopyHostToDev(entryUser->getResult(0), 0);
    }
  }

  return success();
}
|
|
|
|
/// Schedules `op` for erasure by eraseOpsToRemove(). Each operation is recorded
/// at most once, in the order it was first marked.
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyMarked = llvm::is_contained(operationsToRemove, op);
  if (alreadyMarked)
    return;
  operationsToRemove.push_back(op);
}
|
|
|
|
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
|
for (Operation* op : operationsToRemove) {
|
|
op->dropAllUses();
|
|
op->erase();
|
|
}
|
|
}
|
|
|
|
/// Creates a new instance of the Spatial-to-PIM lowering pass.
std::unique_ptr<Pass> createSpatialToPimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<raptor::SpatialToPimPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|