Dynamic gemm/conv
This commit is contained in:
@@ -524,6 +524,39 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization external model attached to PimVVDMulOp.
///
/// The op is handled as a destination-style op: the lhs/rhs operands are
/// reported as memory reads, while the init (output buffer) operand is not.
struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInterface, PimVVDMulOp> {
  /// Returns true for every operand that is not a destination-style init
  /// operand of `op`.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInitOperand = dpsOp.isDpsInit(&opOperand);
    return !isInitOperand;
  }

  /// Rewrites the tensor-level op into a new PimVVDMulOp on buffers.
  ///
  /// Each of lhs, rhs and the output buffer is resolved through
  /// getBufferOrValue; the first failure aborts the rewrite. The resolved
  /// inputs then go through materializeContiguousMemRef and the output
  /// through allocateContiguousMemRefLike before the replacement op is built.
  /// NOTE(review): those helpers are defined outside this view; presumably
  /// they yield contiguous memrefs (copying / allocating as needed) - confirm.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto dotOp = cast<PimVVDMulOp>(op);

    // Resolve operands one at a time, bailing out on the first failure.
    auto lhsBuffer = getBufferOrValue(rewriter, dotOp.getLhs(), options, state);
    if (failed(lhsBuffer))
      return failure();
    auto rhsBuffer = getBufferOrValue(rewriter, dotOp.getRhs(), options, state);
    if (failed(rhsBuffer))
      return failure();
    auto dstBuffer = getBufferOrValue(rewriter, dotOp.getOutputBuffer(), options, state);
    if (failed(dstBuffer))
      return failure();

    // Helper calls stay in lhs, rhs, output order so that any IR they emit
    // is created in the same sequence as before.
    Value lhsMemRef = materializeContiguousMemRef(*lhsBuffer, op->getLoc(), rewriter);
    Value rhsMemRef = materializeContiguousMemRef(*rhsBuffer, op->getLoc(), rewriter);
    Value dstMemRef = allocateContiguousMemRefLike(*dstBuffer, op->getLoc(), rewriter);

    // The replacement op's result type is the type of the output memref.
    replaceOpWithNewBufferizedOp<PimVVDMulOp>(rewriter, op, dstMemRef.getType(), lhsMemRef, rhsMemRef, dstMemRef);
    return success();
  }
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
@@ -576,6 +609,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
||||
PimVVDMulOp::attachInterface<VVDMulOpInterface>(*ctx);
|
||||
|
||||
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
||||
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||
|
||||
@@ -6,6 +6,7 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsAsm.cpp
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
|
||||
@@ -219,6 +219,25 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// spat "vvdmul": dot product of two runtime vectors, yielding one tensor result.
def SpatVVDMulOp : SpatOp<"vvdmul", []> {
  let summary = "Dot product between two runtime vectors";

  // Both operands and the single result use the SpatTensor constraint.
  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  // Requests a hand-written C++ verify() hook for this op.
  let hasVerifier = 1;

  // Textual form: lhs, rhs {attrs} : (lhs-type, rhs-type) -> output-type
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
|
||||
|
||||
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
|
||||
@@ -249,6 +249,32 @@ LogicalResult SpatVMMOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verifies the structural constraints of the vvdmul op.
///
/// Checks, in order (the first violated one is reported):
///   1. lhs, rhs and output are shaped types;
///   2. all three are ranked;
///   3. all three have rank 2;
///   4. all three share one element type;
///   5. lhs and rhs have identical shapes;
///   6. that shape is (1, N) with N > 0;
///   7. the output shape is (1, 1).
LogicalResult SpatVVDMulOp::verify() {
  auto lhsTy = dyn_cast<ShapedType>(getLhs().getType());
  auto rhsTy = dyn_cast<ShapedType>(getRhs().getType());
  auto outTy = dyn_cast<ShapedType>(getOutput().getType());

  const bool allShaped = lhsTy && rhsTy && outTy;
  if (!allShaped)
    return emitError("lhs, rhs, and output must be shaped values");

  const bool allRanked = lhsTy.hasRank() && rhsTy.hasRank() && outTy.hasRank();
  if (!allRanked)
    return emitError("lhs, rhs, and output must have ranked types");

  const bool allRankTwo = lhsTy.getRank() == 2 && rhsTy.getRank() == 2 && outTy.getRank() == 2;
  if (!allRankTwo)
    return emitError("lhs, rhs, and output must have rank 2");

  auto elementTy = lhsTy.getElementType();
  if (elementTy != rhsTy.getElementType() || elementTy != outTy.getElementType())
    return emitError("lhs, rhs, and output must have the same element type");

  ArrayRef<int64_t> vectorShape = lhsTy.getShape();
  if (vectorShape != rhsTy.getShape())
    return emitError("lhs and rhs vector shapes must match");

  // NOTE(review): MLIR encodes a dynamic dimension as a negative sentinel, so
  // a dynamically sized N would presumably be rejected by the `<= 0` test -
  // confirm that static-only vectors are intended.
  const bool isRowVector = vectorShape[0] == 1 && vectorShape[1] > 0;
  if (!isRowVector)
    return emitError("lhs and rhs vector shape must be (1, N) with N > 0");

  ArrayRef<int64_t> resultShape = outTy.getShape();
  const bool isScalarLike = resultShape[0] == 1 && resultShape[1] == 1;
  if (!isScalarLike)
    return emitError("output shape must be (1, 1)");

  return success();
}
|
||||
|
||||
LogicalResult SpatVAddOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
|
||||
@@ -1019,27 +1019,6 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/// Evaluates `map` on `operands`, preferring a constant over an affine.apply.
///
/// When every operand has a known constant integer value and the map
/// constant-folds to an IntegerAttr, the result is produced through
/// createIndexConstant(state, anchor, value). In every other case an
/// affine::AffineApplyOp is created with state.rewriter at `loc` and its
/// result is returned.
Value createAffineApplyOrConstant(
    MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
  // Shared fallback: materialize the map application as an op.
  auto buildApply = [&]() -> Value {
    return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
  };

  // Gather constant operands; a single non-constant one forces the fallback.
  SmallVector<Attribute> constantOperands;
  constantOperands.reserve(operands.size());
  for (Value operand : operands) {
    auto knownValue = getConstantIntValue(operand);
    if (!knownValue)
      return buildApply();
    constantOperands.push_back(state.rewriter.getIndexAttr(*knownValue));
  }

  SmallVector<Attribute> folded;
  if (failed(map.constantFold(constantOperands, folded)))
    return buildApply();

  // Only the first folded result is consulted.
  if (auto foldedInt = dyn_cast<IntegerAttr>(folded.front()))
    return createIndexConstant(state, anchor, foldedInt.getInt());

  return buildApply();
}
|
||||
|
||||
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
|
||||
MLIRContext* context = state.func.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
@@ -1054,7 +1033,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
|
||||
}
|
||||
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
|
||||
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func);
|
||||
}
|
||||
|
||||
Value createIndexedIndexValue(
|
||||
@@ -3396,7 +3375,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
|
||||
|
||||
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
|
||||
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
||||
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
||||
}
|
||||
|
||||
Value createBatchClassRunSourceLane(MaterializerState& state,
|
||||
|
||||
Reference in New Issue
Block a user