From a3dd87f15e3656ee2bec4820ae72a2a4e5662b40 Mon Sep 17 00:00:00 2001 From: "drCathieSo.eth" Date: Sat, 29 Jun 2024 03:40:31 +0800 Subject: [PATCH] Adding Gemm and ArgMax operators to candle-onnx (#2231) * feat(gemm): implement Gemm operator in candle-onnx * feat(onnx): Add support for ArgMax operator in candle-onnx * Apply rustfmt. * Remove argmax as it was already present. --------- Co-authored-by: Laurent --- candle-onnx/src/eval.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 10a3b9377..f7203b36f 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1274,6 +1274,30 @@ fn simple_eval_( let output = candle_nn::ops::leaky_relu(input, alpha.into())?; values.insert(node.output[0].clone(), output); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm + "Gemm" => { + let a = get(&node.input[0])?; + let b = get(&node.input[1])?; + let c = get(&node.input[2])?; + + let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0); + let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0); + + let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?; + let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?; + + let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0); + let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0); + + let a = if trans_a == 0 { a.clone() } else { a.t()? }; + let b = if trans_b == 0 { b.clone() } else { b.t()? }; + + let output = a + .broadcast_mul(&alpha)? + .broadcast_matmul(&b)? + .broadcast_add(&c.broadcast_mul(&beta)?)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } }