Add support for softmax, resize, split, gather
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -455,4 +455,27 @@ def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Vector softmax op for the PIM dialect.
// Destination-style op: the result is written into a caller-provided buffer
// ($outputBuffer) and also returned as a tied SSA result, mirroring the
// sibling unary ops (e.g. PimVSigmOp) that use DestinationStyleOpInterface.
def PimVSoftmaxOp : PimOp<"vsoftmax", [DestinationStyleOpInterface]> {
  let summary = "Softmax over the full input vector";

  let arguments = (ins
    // Input vector to normalize.
    PimTensor:$input,
    // Pre-allocated destination buffer the softmax result is written into.
    PimTensor:$outputBuffer
  );

  let results = (outs
    // Result tensor; tied to $outputBuffer via the DPS interface below.
    PimTensor:$output
  );

  let extraClassDeclaration = [{
    // DestinationStyleOpInterface hook: $outputBuffer is the single
    // "init" (destination) operand of this op.
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBufferMutable();
    }
  }];

  // Custom assembly, e.g.:
  //   pim.vsoftmax(%in, %buf) : (tensor<...>, tensor<...>) -> tensor<...>
  let assemblyFormat = [{
    `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
  }];
}
|
||||
|
||||
#endif // PIM_DIALECT_H
|
||||
|
||||
@@ -273,6 +273,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
|
||||
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
|
||||
PimVSoftmaxOp::attachInterface<UnaryDstOpInterface<PimVSoftmaxOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -272,6 +272,22 @@ def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Softmax op for the Spat dialect. Unlike the PIM counterpart this op has
// value semantics: it takes a single input and produces a fresh result, with
// no destination buffer operand (note the empty trait list).
def SpatSoftmaxOp : SpatOp<"softmax", []> {
  let summary = "Softmax over the full input tensor slice";

  let arguments = (ins
    // Input tensor slice to normalize.
    SpatTensor:$input
  );

  let results = (outs
    // Normalized output tensor.
    SpatTensor:$output
  );

  // Custom assembly, e.g.:
  //   spat.softmax(%in) : tensor<...> -> tensor<...>
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
||||
|
||||
def SpatReluOp : SpatOp<"relu", []> {
|
||||
let summary = "Element-wise ReLU activation";
|
||||
|
||||
|
||||
@@ -202,9 +202,9 @@ private:
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
for (auto users : oldWeightedCompute->getUsers())
|
||||
if (auto funcRet = dyn_cast<func::ReturnOp>(users))
|
||||
funcRet.setOperand(0, newWeightedCompute.getResult(0));
|
||||
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
|
||||
if (isa<func::ReturnOp>(use.getOwner()))
|
||||
use.assign(newWeightedCompute.getResult(0));
|
||||
|
||||
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
|
||||
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
|
||||
|
||||
Reference in New Issue
Block a user