standardize spatial and pim dialects
remove old unused stuff
This commit is contained in:
@@ -16,7 +16,7 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
|
||||
Op<SpatialDialect, mnemonic, traits>;
|
||||
|
||||
// TODO maybe remove and use AnyRankedTensor directly
|
||||
def SpatTensor:
|
||||
def SpatTensor :
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||
|
||||
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
|
||||
@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
||||
let summary = "Virtual channel type";
|
||||
}
|
||||
|
||||
def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compute operation, with constant weights already attached";
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$weights,
|
||||
@@ -49,7 +53,9 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatYieldOp: SpatOp<"yield", [Terminator]> {
|
||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||
let summary = "Yield results from a compute region";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$outputs
|
||||
);
|
||||
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Data movement operations
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatChannelNewOp: SpatOp<"channel_new", []> {
|
||||
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
||||
let summary = "Create a new virtual channel";
|
||||
|
||||
let results = (outs
|
||||
SpatChannelType:$new_channel
|
||||
SpatChannelType:$channel
|
||||
);
|
||||
|
||||
let builders = [
|
||||
@@ -79,108 +87,74 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendOp: SpatOp<"channel_send", []> {
|
||||
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||
let summary = "Send a tensor through a channel";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel,
|
||||
SpatTensor: $data
|
||||
SpatChannelType:$channel,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
|
||||
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveOp: SpatOp<"channel_receive", []> {
|
||||
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
let summary = "Receive a tensor from a channel";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel
|
||||
SpatChannelType:$channel
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor: $data
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
|
||||
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
||||
let summary = "Broadcast a tensor through a shared channel buffer";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel,
|
||||
SpatTensor: $data
|
||||
SpatChannelType:$channel,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
||||
let summary = "Receive a tensor from a shared channel buffer";
|
||||
|
||||
let arguments = (ins
|
||||
SpatChannelType: $channel
|
||||
SpatChannelType:$channel
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor: $data
|
||||
SpatTensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatConstantOp: SpatOp<"constant", []> {
|
||||
let description = [{
|
||||
"Constant value, should be used for weights and biases"
|
||||
let assemblyFormat = [{
|
||||
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr: $value,
|
||||
BoolAttr: $shouldAllocate
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor: $out
|
||||
);
|
||||
}
|
||||
|
||||
def SpatWeightedVMMOp: SpatOp<"Wvmm", []> {
|
||||
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
SpatTensor:$vector
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr: $weightIndex,
|
||||
SpatTensor:$vector
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
def SpatVAddOp: SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $a,
|
||||
SpatTensor: $b
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -190,76 +164,68 @@ def SpatVAddOp: SpatOp<"vadd", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVMulOp: SpatOp<"vmul", []> {
|
||||
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $a,
|
||||
SpatTensor: $b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
//let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVDivOp: SpatOp<"vdiv", []> {
|
||||
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
||||
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$a,
|
||||
SpatTensor:$b
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
//let hasVerifier = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//TODO: remove
|
||||
def SpatVSDivOp: SpatOp<"vsdiv", []> {
|
||||
|
||||
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
|
||||
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$dividend,
|
||||
SpatTensor:$divisor
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatSumOp: SpatOp<"sum", []> {
|
||||
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
|
||||
def SpatVMulOp : SpatOp<"vmul", []> {
|
||||
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $input
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatSigmoidOp: SpatOp<"sigmoid", []> {
|
||||
def SpatSumOp : SpatOp<"sum", []> {
|
||||
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
@@ -267,9 +233,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatReluOp: SpatOp<"relu", []> {
|
||||
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||
let summary = "Element-wise sigmoid activation";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
@@ -277,68 +249,34 @@ def SpatReluOp: SpatOp<"relu", []> {
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
def SpatVMaxOp: SpatOp<"vmax", []> {
|
||||
|
||||
let summary = "Element-wise max function";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor: $a,
|
||||
SpatTensor: $b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
|
||||
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
|
||||
let description = [{
|
||||
Applies a variable number of crossbar weights to a single large image tensor tile,
|
||||
producing a corresponding output tile. This essentially encapsulates a big for loop
|
||||
over all pixels in the input tile, where each pixel is multiplied by all the weights
|
||||
in the operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64ArrayAttr: $weightIndices,
|
||||
I64ArrayAttr: $xKernelPositions,
|
||||
I64ArrayAttr: $yKernelPositions,
|
||||
SpatTensor: $input
|
||||
);
|
||||
let results = (outs SpatTensor);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input) `->` type(results)
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Other operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatImgConcatOp: SpatOp<"img_concat", []> {
|
||||
|
||||
let summary = "Concatenate pixel tiles into a single image";
|
||||
|
||||
let description = [{
|
||||
Concatenate pixel tiles into a single image:
|
||||
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
|
||||
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
|
||||
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
|
||||
|
||||
The input tiles should be provided in a specific order:
|
||||
start from the top left pixel,
|
||||
then continue with the pixel on its right,
|
||||
and once you finish the first row of pixels, go to the next row.
|
||||
}];
|
||||
def SpatReluOp : SpatOp<"relu", []> {
|
||||
let summary = "Element-wise ReLU activation";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$inputs
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVMaxOp : SpatOp<"vmax", []> {
|
||||
let summary = "Element-wise max between two tensors";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -347,9 +285,9 @@ def SpatImgConcatOp: SpatOp<"img_concat", []> {
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // SPATIAL_DIALECT_H
|
||||
#endif // SPATIAL_DIALECT_H
|
||||
|
||||
@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getVector().getType().getShape();
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
/* Two possible accepted shapes:
|
||||
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getVector().getType().getShape();
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
/* Accepted shape:
|
||||
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatImgConcatOp::verify() {
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
size_t channelTileRest = img_c % crossbarSize;
|
||||
|
||||
auto operands = getOperands();
|
||||
|
||||
// Check number of operands
|
||||
if (img_w * img_h * channelTiles != operands.size())
|
||||
return emitError("Number of operands does not match output image size");
|
||||
|
||||
// For each output pixel, check that the inputTiles have a correct shape
|
||||
for (size_t x = 0; x < img_w; x++) {
|
||||
for (size_t y = 0; y < img_h; y++) {
|
||||
size_t channel_counts = 0;
|
||||
for (size_t t = 0; t < channelTiles; t++) {
|
||||
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
|
||||
if (!inputShape)
|
||||
return emitError("Invalid input type, must be ShapedType");
|
||||
|
||||
// N == W == H == 1
|
||||
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
|
||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
||||
|
||||
size_t inputChannels = getImageChannel(inputShape);
|
||||
|
||||
// Check the number of channels in this tile are correct:
|
||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
||||
// - CASE2: common case, the channel count is exactly the crossbarSize
|
||||
if (t == channelTiles - 1 && channelTileRest != 0) {
|
||||
if (inputChannels != channelTileRest)
|
||||
return emitError("Invalid channel count for last tile of pixel");
|
||||
}
|
||||
else {
|
||||
if (inputChannels != crossbarSize)
|
||||
return emitError("Invalid channel count for some pixel tile");
|
||||
}
|
||||
|
||||
channel_counts += inputChannels;
|
||||
}
|
||||
|
||||
if (channel_counts != img_c)
|
||||
emitError("Invalid number of channels for some pixel");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::verify() {
|
||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||
// operand with the same type as the result
|
||||
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
||||
auto operands = getOperands();
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
|
||||
assert(tile < channelTiles);
|
||||
assert(x < img_w);
|
||||
assert(y < img_h);
|
||||
|
||||
return operands[tile + x * channelTiles + y * img_w * channelTiles];
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
@@ -202,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
|
||||
memrefOperands.push_back(outputTensor);
|
||||
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -243,7 +243,7 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
||||
cast<OpTy>(op).getWeightIndexAttr(),
|
||||
memrefOperand,
|
||||
outputTensor)
|
||||
.getOutRes();
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -289,7 +289,7 @@ struct ChannelReceiveOpInterface
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOut();
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
@@ -403,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(outputSize));
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -490,84 +490,6 @@ struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, S
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
// Create a new bufferizable op interface for the apply filters operation.
|
||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
||||
|
||||
// One operand ($input) is read from. All other inputs are only written to.
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
|
||||
// Operand 0: $input
|
||||
// Operand 1: $outBuf
|
||||
// Operand 2: $accumBuf
|
||||
return opOperand.getOperandNumber() == 0;
|
||||
}
|
||||
|
||||
// One input ($accumBuf) is written to. All other inputs are only read.
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
|
||||
// Operand 0: $input
|
||||
// Operand 1: $outBuf
|
||||
// Operand 2: $accumBuf
|
||||
return opOperand.getOperandNumber() == 2;
|
||||
}
|
||||
|
||||
// No operands are aliased with any other operands.
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Bufferize the operation.
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
// Get the input tensor buffer.
|
||||
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
|
||||
if (failed(inputBuffer))
|
||||
return failure();
|
||||
|
||||
// Create a new buffer for the output tensor.
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
// Create a new buffer for the accumulation buffer.
|
||||
// To do this, create a new allocation operation. Size must be axbx1x1,
|
||||
// where axbxcxd is the size of the output tensor. Since the shape is
|
||||
// different, we can't immediately use createEmptyFromType, we first need to
|
||||
// create the shape of the accumulation buffer.
|
||||
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
|
||||
|
||||
// Set the last two dimensions to 1.
|
||||
accumShape[accumShape.size() - 1] = 1;
|
||||
accumShape[accumShape.size() - 2] = 1;
|
||||
|
||||
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
|
||||
|
||||
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
|
||||
|
||||
// Bufferize the operation.
|
||||
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
|
||||
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
||||
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
||||
|
||||
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
weightIndices,
|
||||
xKernelPositions,
|
||||
yKernelPositions,
|
||||
*inputBuffer,
|
||||
outputTensor,
|
||||
accumBuffer);
|
||||
|
||||
// Replace the operation with the bufferized value.
|
||||
replaceOpWithBufferizedValues(rewriter, op, bufferized);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
|
||||
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
|
||||
@@ -580,7 +502,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
|
||||
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
|
||||
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user