This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -12,6 +13,8 @@
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@@ -104,23 +107,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
OperationFolder& constantFolder) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||
|
||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||
return PimMemCopyHostToDevBatchOp::create(
|
||||
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
return PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
outputBuffer,
|
||||
zeroValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
sizeAttr)
|
||||
.getOutput();
|
||||
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
static Value
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
@@ -131,7 +145,7 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
@@ -151,6 +165,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
OperationFolder constantFolder(&getContext());
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect,
|
||||
@@ -181,16 +196,17 @@ void SpatialToPimPass::runOnOperation() {
|
||||
RewritePatternSet globalTensorPatterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
@@ -251,7 +267,6 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
@@ -302,6 +317,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
@@ -309,7 +325,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
auto paddedOutputType = RankedTensorType::get(
|
||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||
@@ -336,10 +352,13 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||
Type elementType = tensorType.getElementType();
|
||||
if (!elementType.isIntOrFloat())
|
||||
return;
|
||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||
|
||||
@@ -349,10 +368,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(
|
||||
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
|
||||
deviceTensor,
|
||||
inputTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||
|
||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||
|
||||
Reference in New Issue
Block a user