diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 8f1b44c..74f52cf 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -3,6 +3,7 @@ add_pim_library(OMPimCommon IR/CoreBlockUtils.cpp IR/EntryPointUtils.cpp IR/ShapeUtils.cpp + IR/SubviewUtils.cpp IR/WeightUtils.cpp Support/DebugDump.cpp Support/Diagnostics.cpp diff --git a/src/PIM/Common/IR/SubviewUtils.cpp b/src/PIM/Common/IR/SubviewUtils.cpp new file mode 100644 index 0000000..6284b00 --- /dev/null +++ b/src/PIM/Common/IR/SubviewUtils.cpp @@ -0,0 +1,85 @@ +#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" + +#include "mlir/IR/BuiltinTypeInterfaces.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +Value stripMemRefCasts(Value value) { + while (auto castOp = value.getDefiningOp()) + value = castOp.getSource(); + return value; +} + +Value stripMemRefViewOps(Value value) { + while (true) { + if (auto castOp = value.getDefiningOp()) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = value.getDefiningOp()) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = value.getDefiningOp()) { + value = expandOp.getSrc(); + continue; + } + return value; + } +} + +bool hasAllStaticSubviewParts(memref::SubViewOp subview) { + return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); +} + +FailureOr getStaticSubviewInfo(Value value) { + value = stripMemRefViewOps(value); + auto subviewOp = value.getDefiningOp(); + if (!subviewOp) + return failure(); + + auto source = stripMemRefCasts(subviewOp.getSource()); + auto sourceType = dyn_cast(source.getType()); + auto subviewType = dyn_cast(subviewOp.getType()); + if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) + return failure(); + + StaticSubviewInfo info; + info.source = source; + info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end()); + SmallVector mixedOffsets = subviewOp.getMixedOffsets(); + info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end()); + for (OpFoldResult size : subviewOp.getMixedSizes()) { + auto staticSize = getConstantIntValue(size); + if (!staticSize) + return failure(); + info.sizes.push_back(*staticSize); + } + for (OpFoldResult stride : subviewOp.getMixedStrides()) { + auto staticStride = getConstantIntValue(stride); + if (!staticStride) + return failure(); + info.strides.push_back(*staticStride); + } + return info; +} + +FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info) { + SmallVector staticOffsets; + staticOffsets.reserve(info.offsets.size()); + for (OpFoldResult offset : info.offsets) { + auto staticOffset = getConstantIntValue(offset); + if (!staticOffset) + return failure(); + staticOffsets.push_back(*staticOffset); + } + return staticOffsets; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/SubviewUtils.hpp b/src/PIM/Common/IR/SubviewUtils.hpp new file mode 100644 index 0000000..de53782 --- /dev/null +++ b/src/PIM/Common/IR/SubviewUtils.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" + +namespace onnx_mlir { + +struct StaticSubviewInfo { + mlir::Value source; + llvm::SmallVector sourceShape; + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; +}; + +mlir::Value stripMemRefCasts(mlir::Value value); + +mlir::Value stripMemRefViewOps(mlir::Value value); + +bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview); + +llvm::FailureOr getStaticSubviewInfo(mlir::Value value); + +/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. +llvm::FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info); + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index d911d4b..7bc6409 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -24,8 +24,8 @@ #include #include -#include "Common/PimCommon.hpp" #include "Common/IR/CompactAsmUtils.hpp" +#include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" #include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" @@ -206,9 +206,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu return iter->second.address + resolvedAddress->byteOffset; } -void PimAcceleratorMemory::reportHost() { - hostReportRow = hostMem.getReportRow(); -} +void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); } void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) { reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast(coreId)}, row}); @@ -810,7 +808,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { else if (auto getGlobalOp = dyn_cast(op)) coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); else { - op.emitError("Unsupported codegen for this operation"); + InFlightDiagnostic diag = op.emitError() + << "unsupported codegen for op '" << op.getName().getStringRef() << "'"; + if (auto coreOp = op.getParentOfType()) + diag << " inside pim.core " << coreOp.getCoreId(); + else if (auto coreBatchOp = op.getParentOfType()) + diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount(); return failure(); } processedOperations++; @@ -935,9 +938,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: if (auto coreOp = dyn_cast(op)) { if (auto err = emitCore(coreOp, false)) return err; - memory.recordCoreReport(emittedCoreIds.lookup(static_cast(coreOp.getCoreId())), - memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast(coreOp.getCoreId()))) - .getReportRow()); + memory.recordCoreReport( + emittedCoreIds.lookup(static_cast(coreOp.getCoreId())), + memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast(coreOp.getCoreId()))).getReportRow()); continue; } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 8536b4a..42aa656 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -29,7 +29,7 @@ struct MemoryReportRow { bool operator==(const MemoryReportRow& other) const { return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal - && sizeGlobal == other.sizeGlobal; + && sizeGlobal == other.sizeGlobal; } }; diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 2e318ee..0dbf870 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -10,6 +10,8 @@ #include #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" @@ -30,19 +32,6 @@ struct DenseWeightView { int64_t offset = 0; }; -SmallVector computeRowMajorStridesForShape(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - for (int64_t index = static_cast(shape.size()) - 2; index >= 0; --index) - strides[index] = strides[index + 1] * shape[index + 1]; - return strides; -} - -bool allStaticSubviewParts(memref::SubViewOp subview) { - return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) - && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) - && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); -} - FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { SmallVector subviews; mlir::Value current = weight; @@ -55,7 +44,7 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value if ((getGlobalOp = dyn_cast(defOp))) break; if (auto subview = dyn_cast(defOp)) { - if (!allStaticSubviewParts(subview)) + if (!hasAllStaticSubviewParts(subview)) return failure(); subviews.push_back(subview); current = subview.getSource(); @@ -79,7 +68,7 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value DenseWeightView view; view.denseAttr = denseAttr; view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); - view.strides = computeRowMajorStridesForShape(view.shape); + view.strides = computeRowMajorStrides(view.shape); for (memref::SubViewOp subview : llvm::reverse(subviews)) { SmallVector nextStrides; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index d05016e..a84e118 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -4,9 +4,9 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" #include #include @@ -47,8 +47,8 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides); } -static Value createPoolFillElement( - ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { +static Value +createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { if (!useMinimumValue) return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); @@ -65,8 +65,10 @@ static Value createPoolFillElement( llvm_unreachable("unsupported pool element type"); } -static Value createPoolFillTensor( - ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) { +static Value createPoolFillTensor(ConversionPatternRewriter& rewriter, + Location loc, + RankedTensorType tensorType, + bool useMinimumValue) { auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue); return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement); } @@ -90,10 +92,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter, inputType.getDimSize(3) + padLeft + padRight}, inputType.getElementType(), inputType.getEncoding()); - SmallVector lowPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padTop), - rewriter.getIndexAttr(padLeft)}; + SmallVector lowPads = { + rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)}; SmallVector highPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padBottom), @@ -104,8 +104,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter, padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); - Value padValue = createPoolFillElement( - rewriter, loc, inputType.getElementType(), std::is_same_v); + Value padValue = + createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v); tensor::YieldOp::create(rewriter, loc, padValue); rewriter.setInsertionPointAfter(padOp); return padOp.getResult(); @@ -279,7 +279,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { constexpr size_t numInputs = 1; auto computeOp = createSpatCompute(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult { - Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); + Value paddedInput = + createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()); Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); @@ -307,8 +308,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); - Value reducedWindow = createPoolFillTensor( - rewriter, loc, tileType, std::is_same_v); + Value reducedWindow = + createPoolFillTensor(rewriter, loc, tileType, std::is_same_v); for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { Value paddedInH = windowBaseH; @@ -324,18 +325,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); } - SmallVector offsets = {batchIndex, - rewriter.getIndexAttr(channelTile * xbarSize), - paddedInH, - paddedInW}; + SmallVector offsets = { + batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW}; SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileChannels), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; + SmallVector strides = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value windowValue = tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); windowValue = materializeContiguousTile(rewriter, loc, windowValue); @@ -344,36 +341,28 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { } if constexpr (std::is_same_v) { - SmallVector scaleOffsets = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(channelTile * xbarSize), - outHeightIndex, - outWidthIndex}; + SmallVector scaleOffsets = { + rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex}; SmallVector scaleSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileChannels), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - SmallVector scaleStrides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; + SmallVector scaleStrides = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value scaleSlice = tensor::ExtractSliceOp::create( rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice); reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice); } - SmallVector outputOffsets = {batchIndex, - rewriter.getIndexAttr(channelTile * xbarSize), - outHeightIndex, - outWidthIndex}; + SmallVector outputOffsets = { + batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex}; SmallVector outputSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileChannels), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - SmallVector outputStrides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; + SmallVector outputStrides = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; updatedOutput = tensor::InsertSliceOp::create( rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 7a48fca..d5508e5 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -9,12 +9,14 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" #include #include @@ -147,6 +149,73 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite rewriter.replaceOp(extractRowsOp, replacements); } +static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { + auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType(); + auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType())); + + for (auto globalOp : moduleOp.getOps()) { + if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue()) + continue; + if (dyn_cast(*globalOp.getInitialValue()) == zeroAttr) + return globalOp; + } + + std::string nameStem; + llvm::raw_string_ostream nameStream(nameStem); + nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements(); + nameStream.flush(); + + std::string symbolName = nameStem; + unsigned suffix = 0; + while (SymbolTable::lookupSymbolIn(moduleOp, symbolName)) + symbolName = (nameStem + "_" + Twine(suffix++)).str(); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + return memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(symbolName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + zeroAttr, + rewriter.getUnitAttr(), + IntegerAttr {}); +} + +static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); + auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); + auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); + auto zeroAttr = rewriter.getI32IntegerAttr(0); + auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(tensorType))); + + if (outputBuffer->getParentOfType()) + return PimMemCopyHostToDevBatchOp::create( + rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr) + .getOutput(); + + return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr) + .getOutput(); +} + +static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) { + auto vectorType = cast(vector.getType()); + ArrayRef shape = vectorType.getShape(); + assert(isHVectorShape(shape) && "expected a horizontal vector"); + assert(shape[1] <= static_cast(crossbarSize) && "vector width must fit in one crossbar"); + + if (shape[1] == static_cast(crossbarSize)) + return vector; + + auto paddedType = RankedTensorType::get( + {shape[0], static_cast(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); + Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType); + auto zeroAttr = rewriter.getI32IntegerAttr(0); + auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(vectorType))); + return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput(); +} + static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { SmallVector concatOps; funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); }); @@ -426,54 +495,35 @@ void SpatialToPimPass::runOnOperation() { } void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { - auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { - auto* definingOp = value.getDefiningOp(); - if (!definingOp) - return; - auto dpsDefiningOp = dyn_cast(definingOp); - if (!dpsDefiningOp) - return; - auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value)); - if (!tiedOperand) - return; - Value tiedValue = tiedOperand->get(); - assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use"); - tiedValue.setType(newType); - self(tiedValue, newType, self); - }; - funcOp.walk([&](PimVMMOp vmmOp) { - auto outTensorOperand = vmmOp.getOutputBuffer(); - auto resultTensor = vmmOp.getOutput(); - auto outShape = getTensorShape(outTensorOperand); - assert(isHVectorShape(outShape)); - if (outShape[1] != static_cast(crossbarSize)) { - auto newShape = SmallVector {outShape[0], static_cast(crossbarSize)}; - auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType()); - if (outTensorOperand == vmmOp.getInput()) { - rewriter.setInsertionPoint(vmmOp); - auto newOutputBuffer = - tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType()); - vmmOp.getOutputBufferMutable().assign(newOutputBuffer); - } - else { - enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain); - outTensorOperand.setType(newType); - } - resultTensor.setType(newType); + auto outputType = cast(vmmOp.getOutput().getType()); + ArrayRef outputShape = outputType.getShape(); + assert(isHVectorShape(outputShape) && "expected a horizontal vector output"); + assert(outputShape[1] <= static_cast(crossbarSize) && "output width must fit in one crossbar"); - IntegerAttr zeroAttr = rewriter.getIndexAttr(0); - IntegerAttr oneAttr = rewriter.getIndexAttr(1); - IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]); - IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]); - SmallVector offsets = {zeroAttr, zeroAttr}; - SmallVector sizes = {oldShapeZeroAttr, oldShapeOneAttr}; - SmallVector strides = {oneAttr, oneAttr}; - rewriter.setInsertionPointAfter(vmmOp); - auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides); - SmallPtrSet exceptions = {vmmOp, sliceOp}; - resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions); - } + rewriter.setInsertionPoint(vmmOp); + Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput()); + auto paddedOutputType = RankedTensorType::get( + {outputShape[0], static_cast(crossbarSize)}, outputType.getElementType(), outputType.getEncoding()); + Value paddedOutputBuffer = outputShape[1] == static_cast(crossbarSize) + ? vmmOp.getOutputBuffer() + : createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult(); + vmmOp.getInputMutable().assign(paddedInput); + vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer); + + vmmOp.getOutput().setType(paddedOutputType); + + if (outputShape[1] == static_cast(crossbarSize)) + return; + + SmallVector offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + rewriter.setInsertionPointAfter(vmmOp); + auto sliceOp = + tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides); + SmallPtrSet exceptions = {vmmOp, sliceOp}; + vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions); }); } diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index c97f174..07022a2 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -389,6 +389,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> { } }]; + let hasVerifier = 1; let assemblyFormat = [{ `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) }]; diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index 05cef60..e9ce9cf 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -4,6 +4,7 @@ #include "llvm/Support/LogicalResult.h" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; @@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef coreI return success(); } +static FailureOr> getWeightShapeForVMM(Operation* op, size_t weightIndex) { + if (auto coreOp = op->getParentOfType()) { + if (weightIndex >= coreOp.getWeights().size()) + return failure(); + return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); + } + + if (auto coreBatchOp = op->getParentOfType()) { + if (weightIndex >= coreBatchOp.getWeights().size()) + return failure(); + return cast(coreBatchOp.getWeights()[weightIndex].getType()).getShape(); + } + + return failure(); +} + } // namespace LogicalResult PimSendTensorOp::verify() { @@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() { getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch"); } +LogicalResult PimVMMOp::verify() { + if (failed(verifyCompatibleShapedTypes( + getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) + return failure(); + + auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex()); + if (failed(matrixShapeOpt)) + return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex"); + ArrayRef matrixShape = *matrixShapeOpt; + + auto vectorType = dyn_cast(getInput().getType()); + auto outputType = dyn_cast(getOutput().getType()); + if (!vectorType || !outputType) + return emitError("input and output must be shaped types"); + + ArrayRef vectorShape = vectorType.getShape(); + ArrayRef outputShape = outputType.getShape(); + + if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) + return emitError("matrix, vector and output must have rank 2"); + + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + if (N <= 0 || M <= 0) + return emitError("matrix shape must be (N, M) with N > 0 and M > 0"); + if (N > static_cast(crossbarSize) || M > static_cast(crossbarSize)) + return emitError("matrix dimensions must fit in one crossbar"); + + int64_t vector1 = vectorShape[0]; + int64_t vectorWidth = vectorShape[1]; + if (vector1 != 1 || vectorWidth != static_cast(crossbarSize)) + return emitError("vector shape must be (1, crossbar-size)"); + + int64_t output1 = outputShape[0]; + int64_t outputWidth = outputShape[1]; + if (output1 != 1 || outputWidth != static_cast(crossbarSize)) + return emitError("output shape must be (1, crossbar-size)"); + + return success(); +} + LogicalResult PimConcatOp::verify() { if (getInputs().empty()) return emitError("requires at least one input"); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 1fad5b5..1e724a8 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -105,6 +105,37 @@ struct MemCopyDevToHostOpInterface } }; +struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto memCopyOp = cast(op); + + auto targetOpt = getBufferOrValue(rewriter, memCopyOp.getTarget(), options, state); + if (failed(targetOpt)) + return failure(); + + auto sourceOpt = getBufferOrValue(rewriter, memCopyOp.getSource(), options, state); + if (failed(sourceOpt)) + return failure(); + + replaceOpWithNewBufferizedOp(rewriter, + memCopyOp, + targetOpt->getType(), + *targetOpt, + *sourceOpt, + memCopyOp.getTargetOffsetAttr(), + memCopyOp.getSourceOffsetAttr(), + memCopyOp.getSizeAttr()); + return success(); + } +}; + struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); @@ -626,6 +657,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimMemCopyHostToDevOp::attachInterface(*ctx); PimMemCopyHostToDevBatchOp::attachInterface(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); + PimMemCopyOp::attachInterface(*ctx); PimTransposeOp::attachInterface(*ctx); PimVMMOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 00e4fc4..96286dc 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -52,9 +52,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) { printer << " "; printer.printOperand(op.getInput()); printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); - printer.printOptionalAttrDict( - op->getAttrs(), - {op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()}); + printer.printOptionalAttrDict(op->getAttrs(), + {op.getChannelIdsAttrName().getValue(), + op.getSourceCoreIdsAttrName().getValue(), + op.getTargetCoreIdsAttrName().getValue()}); printer << " : "; printer.printType(op.getInput().getType()); } @@ -62,9 +63,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) { template static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) { printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); - printer.printOptionalAttrDict( - op->getAttrs(), - {op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()}); + printer.printOptionalAttrDict(op->getAttrs(), + {op.getChannelIdsAttrName().getValue(), + op.getSourceCoreIdsAttrName().getValue(), + op.getTargetCoreIdsAttrName().getValue()}); printer << " : "; printer.printType(op.getOutput().getType()); } diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp index 8cbaeab..8593dd5 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp @@ -41,30 +41,6 @@ struct DenseSubviewKeyInfo { } // namespace -Value stripMemRefCasts(Value value) { - while (auto castOp = value.getDefiningOp()) - value = castOp.getSource(); - return value; -} - -Value stripMemRefViewOps(Value value) { - while (true) { - if (auto castOp = value.getDefiningOp()) { - value = castOp.getSource(); - continue; - } - if (auto collapseOp = value.getDefiningOp()) { - value = collapseOp.getSrc(); - continue; - } - if (auto expandOp = value.getDefiningOp()) { - value = expandOp.getSrc(); - continue; - } - return value; - } -} - memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, Location loc, MemRefType globalType, @@ -177,48 +153,4 @@ FailureOr foldDenseSourceToType(ModuleOp moduleOp, Value sour return *denseAttr; } -FailureOr getStaticSubviewInfo(Value value) { - value = stripMemRefViewOps(value); - auto subviewOp = value.getDefiningOp(); - if (!subviewOp) - return failure(); - - auto source = stripMemRefCasts(subviewOp.getSource()); - auto sourceType = dyn_cast(source.getType()); - auto subviewType = dyn_cast(subviewOp.getType()); - if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) - return failure(); - - StaticSubviewInfo info; - info.source = source; - info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end()); - SmallVector mixedOffsets = subviewOp.getMixedOffsets(); - info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end()); - for (OpFoldResult size : subviewOp.getMixedSizes()) { - auto staticSize = getConstantIntValue(size); - if (!staticSize) - return failure(); - info.sizes.push_back(*staticSize); - } - for (OpFoldResult stride : subviewOp.getMixedStrides()) { - auto staticStride = getConstantIntValue(stride); - if (!staticStride) - return failure(); - info.strides.push_back(*staticStride); - } - return info; -} - -FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info) { - SmallVector staticOffsets; - staticOffsets.reserve(info.offsets.size()); - for (OpFoldResult offset : info.offsets) { - auto staticOffset = getConstantIntValue(offset); - if (!staticOffset) - return failure(); - staticOffsets.push_back(*staticOffset); - } - return staticOffsets; -} - } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp index 2498658..77dac10 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp @@ -6,23 +6,12 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" + namespace onnx_mlir { -struct StaticSubviewInfo { - mlir::Value source; - llvm::SmallVector sourceShape; - llvm::SmallVector offsets; - llvm::SmallVector sizes; - llvm::SmallVector strides; -}; - -mlir::Value stripMemRefCasts(mlir::Value value); - -mlir::Value stripMemRefViewOps(mlir::Value value); - mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp, mlir::Location loc, mlir::MemRefType globalType, @@ -39,9 +28,4 @@ llvm::FailureOr getDenseGlobalValue(mlir::ModuleOp modu llvm::FailureOr foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType); -llvm::FailureOr getStaticSubviewInfo(mlir::Value value); - -/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. -llvm::FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info); - } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 8d858f5..40b445a 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -9,6 +9,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" using namespace mlir; @@ -67,12 +68,6 @@ static bool isCodegenAddressableValue(Value value) { } static bool isConstantGlobalView(Value value) { - auto allStaticSubviewParts = [](memref::SubViewOp subview) { - return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) - && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) - && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); - }; - while (true) { Operation* defOp = value.getDefiningOp(); if (!defOp) @@ -84,7 +79,7 @@ static bool isConstantGlobalView(Value value) { && isa(*globalOp.getInitialValue()); } if (auto subview = dyn_cast(defOp)) { - if (!allStaticSubviewParts(subview)) + if (!hasAllStaticSubviewParts(subview)) return false; value = subview.getSource(); continue;