Dynamic gemm/conv
This commit is contained in:
@@ -6,6 +6,7 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsAsm.cpp
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
|
||||
@@ -219,6 +219,25 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// "vvdmul": dot product between two runtime vectors.
// Takes two SpatTensor operands ($lhs, $rhs) and yields one SpatTensor
// ($output). `hasVerifier = 1` requests a hand-written C++ verify() hook for
// this op; operand/result constraints beyond the SpatTensor type live there.
def SpatVVDMulOp : SpatOp<"vvdmul", []> {
  let summary = "Dot product between two runtime vectors";

  let arguments = (ins SpatTensor:$lhs, SpatTensor:$rhs);
  let results = (outs SpatTensor:$output);

  let hasVerifier = 1;

  // Custom form: %lhs, %rhs attr-dict : (lhs-type, rhs-type) -> output-type
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
  }];
}
|
||||
|
||||
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
|
||||
@@ -249,6 +249,32 @@ LogicalResult SpatVMMOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verifier for SpatVVDMulOp.
///
/// Accepts the op only when, checked in this order:
///   1. lhs, rhs and output are all ShapedType values,
///   2. all three are ranked,
///   3. all three have rank 2,
///   4. all three share one element type,
///   5. lhs and rhs have identical shapes,
///   6. that shared shape is (1, N) with N > 0,
///   7. the output shape is (1, 1).
/// The first failing check emits its diagnostic and returns failure.
LogicalResult SpatVVDMulOp::verify() {
  auto lhsTy = dyn_cast<ShapedType>(getLhs().getType());
  auto rhsTy = dyn_cast<ShapedType>(getRhs().getType());
  auto outTy = dyn_cast<ShapedType>(getOutput().getType());

  const bool allShaped = lhsTy && rhsTy && outTy;
  if (!allShaped)
    return emitError("lhs, rhs, and output must be shaped values");

  const bool allRanked = lhsTy.hasRank() && rhsTy.hasRank() && outTy.hasRank();
  if (!allRanked)
    return emitError("lhs, rhs, and output must have ranked types");

  // Rank is only queried once every type is known to be ranked.
  if (lhsTy.getRank() != 2 || rhsTy.getRank() != 2 || outTy.getRank() != 2)
    return emitError("lhs, rhs, and output must have rank 2");

  const auto elemTy = lhsTy.getElementType();
  if (rhsTy.getElementType() != elemTy || outTy.getElementType() != elemTy)
    return emitError("lhs, rhs, and output must have the same element type");

  if (lhsTy.getShape() != rhsTy.getShape())
    return emitError("lhs and rhs vector shapes must match");

  // rhs has the same shape as lhs at this point, so checking lhs covers both.
  if (lhsTy.getDimSize(0) != 1 || lhsTy.getDimSize(1) <= 0)
    return emitError("lhs and rhs vector shape must be (1, N) with N > 0");

  if (outTy.getDimSize(0) != 1 || outTy.getDimSize(1) != 1)
    return emitError("output shape must be (1, 1)");

  return success();
}
|
||||
|
||||
LogicalResult SpatVAddOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
|
||||
@@ -1019,27 +1019,6 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/// Produces a Value for `map` applied to `operands`.
///
/// If getConstantIntValue yields a value for every operand and
/// AffineMap::constantFold succeeds with an IntegerAttr as its first result,
/// the result comes from createIndexConstant(state, anchor, <folded value>).
/// In every other case an affine::AffineApplyOp is created through
/// state.rewriter at `loc` and its result is returned.
Value createAffineApplyOrConstant(
    MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
  // Shared fallback: build the affine.apply op and hand back its result.
  auto emitApply = [&]() -> Value {
    return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
  };

  SmallVector<Attribute> constArgs;
  constArgs.reserve(operands.size());
  for (Value arg : operands) {
    auto maybeInt = getConstantIntValue(arg);
    // One operand without a constant value is enough to rule out folding.
    if (!maybeInt)
      return emitApply();
    constArgs.push_back(state.rewriter.getIndexAttr(*maybeInt));
  }

  SmallVector<Attribute> folded;
  if (!succeeded(map.constantFold(constArgs, folded)))
    return emitApply();

  if (auto asInt = dyn_cast<IntegerAttr>(folded.front()))
    return createIndexConstant(state, anchor, asInt.getInt());

  return emitApply();
}
|
||||
|
||||
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
|
||||
MLIRContext* context = state.func.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
@@ -1054,7 +1033,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
|
||||
}
|
||||
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
|
||||
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func);
|
||||
}
|
||||
|
||||
Value createIndexedIndexValue(
|
||||
@@ -3396,7 +3375,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
|
||||
|
||||
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
|
||||
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
||||
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
||||
}
|
||||
|
||||
Value createBatchClassRunSourceLane(MaterializerState& state,
|
||||
|
||||
Reference in New Issue
Block a user