diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index c152f31fb0..e7e3e129d6 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -175,7 +175,7 @@ impl Tensor {
             // the backprop graph of the backprop itself. This would be an issue for second order
             // derivatives but these are out of scope at the moment.
             let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
-            let grad = if do_not_detach { grad } else { grad.detach()? };
+            let grad = if do_not_detach { grad } else { grad.detach() };
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 5f0b6df919..8596c95773 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1882,9 +1882,9 @@ impl Tensor {
     /// this new node. The storage of this tensor is shared with the initial tensor.
     ///
     /// If the tensor is already detached from the computation graph, the same tensor is returned.
-    pub fn detach(&self) -> Result<Tensor> {
+    pub fn detach(&self) -> Tensor {
         if self.op.is_none() && !self.is_variable {
-            Ok(self.clone())
+            self.clone()
         } else {
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -1895,7 +1895,7 @@ impl Tensor {
                 dtype: self.dtype,
                 device: self.device.clone(),
             };
-            Ok(Tensor(Arc::new(tensor_)))
+            Tensor(Arc::new(tensor_))
         }
     }
 
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 61800bf33d..bdf8da4aed 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -107,6 +107,10 @@ impl Var {
         Ok(Self(inner))
     }
 
+    pub fn as_detached_tensor(&self) -> Tensor {
+        self.0.detach()
+    }
+
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index 1ce4889e6b..5309eaf669 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -411,7 +411,7 @@ impl DDPG<'_> {
     pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
         let actions = self
             .actor
-            .forward(&state.detach()?.unsqueeze(0)?)?
+            .forward(&state.detach().unsqueeze(0)?)?
             .squeeze(0)?;
         let actions = if self.train {
             (actions + self.ou_noise.sample()?)?
diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
index 044cbfcd2f..6c355fe62f 100644
--- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs
+++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
         loop {
             let action = {
                 let action_probs: Vec<f32> =
-                    softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+                    softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                         .squeeze(0)?
                         .to_vec1()?;
                 weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
 
         let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
             .to_dtype(DType::F32)?
-            .detach()?;
+            .detach();
 
         let actions_mask = {
             let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
                     .unwrap()
             })
             .collect();
-            Tensor::stack(&actions_mask, 0)?.detach()?
+            Tensor::stack(&actions_mask, 0)?.detach()
        };
 
        let states = {
            let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
-            Tensor::stack(&states, 0)?.detach()?
+            Tensor::stack(&states, 0)?.detach()
        };
 
        let log_probs = actions_mask
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 856c2c7a21..4c67961d06 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -262,9 +262,19 @@ impl BatchNorm {
         let target_shape = target_shape.as_slice();
 
         let x = x
-            .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+            .broadcast_sub(
+                &self
+                    .running_mean
+                    .as_detached_tensor()
+                    .reshape(target_shape)?,
+            )?
             .broadcast_div(
-                &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+                &(self
+                    .running_var
+                    .as_detached_tensor()
+                    .reshape(target_shape)?
+                    + self.eps)?
+                    .sqrt()?,
             )?;
 
         match &self.weight_and_bias {
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 37b8fe8c68..aef0707d51 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -88,23 +88,27 @@ class QTensor:
         Dequantizes the tensor.
         """
         pass
+
     @property
     def ggml_dtype(self) -> str:
         """
         Gets the tensors quantized dtype.
         """
         pass
+
     def matmul_t(self, lhs: Tensor) -> Tensor:
         """
         Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
         """
         pass
+
     @property
     def rank(self) -> int:
         """
         Gets the rank of the tensor.
         """
         pass
+
     @property
     def shape(self) -> Tuple[int]:
         """
@@ -119,178 +123,213 @@ class Tensor:
     def __init__(self, data: _ArrayLike):
         pass
+
     def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Add a scalar to a tensor or two tensors together.
         """
         pass
+
     def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
         """
         Return a slice of a tensor.
         """
         pass
+
     def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Multiply a tensor by a scalar or one tensor by another.
         """
         pass
+
     def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Add a scalar to a tensor or two tensors together.
         """
         pass
+
     def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
         """
         Compare a tensor with a scalar or one tensor with another.
         """
         pass
+
     def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Multiply a tensor by a scalar or one tensor by another.
         """
         pass
+
     def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Subtract a scalar from a tensor or one tensor from another.
         """
         pass
+
     def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Divide a tensor by a scalar or one tensor by another.
         """
         pass
+
     def abs(self) -> Tensor:
         """
         Performs the `abs` operation on the tensor.
         """
         pass
+
     def argmax_keepdim(self, dim: int) -> Tensor:
         """
         Returns the indices of the maximum value(s) across the selected dimension.
""" pass + def argmin_keepdim(self, dim: int) -> Tensor: """ Returns the indices of the minimum value(s) across the selected dimension. """ pass + def broadcast_add(self, rhs: Tensor) -> Tensor: """ Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass + def broadcast_as(self, *shape: Shape) -> Tensor: """ Broadcasts the tensor to the given shape. """ pass + def broadcast_div(self, rhs: Tensor) -> Tensor: """ Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass + def broadcast_left(self, *shape: Shape) -> Tensor: """ Broadcasts the tensor to the given shape, adding new dimensions on the left. """ pass + def broadcast_mul(self, rhs: Tensor) -> Tensor: """ Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass + def broadcast_sub(self, rhs: Tensor) -> Tensor: """ Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass + def contiguous(self) -> Tensor: """ Makes the tensor contiguous in memory. """ pass + def copy(self) -> Tensor: """ Returns a copy of the tensor. """ pass + def cos(self) -> Tensor: """ Performs the `cos` operation on the tensor. """ pass + def detach(self) -> Tensor: """ Detach the tensor from the computation graph. """ pass + @property def device(self) -> Device: """ Gets the tensor's device. """ pass + @property def dtype(self) -> DType: """ Gets the tensor's dtype. """ pass + def exp(self) -> Tensor: """ Performs the `exp` operation on the tensor. """ pass + def flatten_all(self) -> Tensor: """ Flattens the tensor into a 1D tensor. """ pass + def flatten_from(self, dim: int) -> Tensor: """ Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension. """ pass + def flatten_to(self, dim: int) -> Tensor: """ Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive). """ pass + def get(self, index: int) -> Tensor: """ Gets the value at the specified index. """ pass + def index_select(self, rhs: Tensor, dim: int) -> Tensor: """ Select values for the input tensor at the target indexes across the specified dimension. @@ -302,161 +341,192 @@ class Tensor: tensor. """ pass + def is_contiguous(self) -> bool: """ Returns true if the tensor is contiguous in C order. """ pass + def is_fortran_contiguous(self) -> bool: """ Returns true if the tensor is contiguous in Fortran order. """ pass + def log(self) -> Tensor: """ Performs the `log` operation on the tensor. """ pass + def matmul(self, rhs: Tensor) -> Tensor: """ Performs a matrix multiplication between the two tensors. """ pass + def max_keepdim(self, dim: int) -> Tensor: """ Gathers the maximum value across the selected dimension. """ pass + def mean_all(self) -> Tensor: """ Returns the mean of the tensor. """ pass + def min_keepdim(self, dim: int) -> Tensor: """ Gathers the minimum value across the selected dimension. """ pass + def narrow(self, dim: int, start: int, len: int) -> Tensor: """ Returns a new tensor that is a narrowed version of the input, the dimension `dim` ranges from `start` to `start + len`. """ pass + @property def nelement(self) -> int: """ Gets the tensor's element count. """ pass + def powf(self, p: float) -> Tensor: """ Performs the `pow` operation on the tensor with the given exponent. 
""" pass + def quantize(self, quantized_dtype: str) -> QTensor: """ Quantize the tensor. """ pass + @property def rank(self) -> int: """ Gets the tensor's rank. """ pass + def recip(self) -> Tensor: """ Get the `recip` of the tensor. """ pass + def reshape(self, *shape: Shape) -> Tensor: """ Reshapes the tensor to the given shape. """ pass + @property def shape(self) -> Tuple[int]: """ Gets the tensor's shape. """ pass + def sin(self) -> Tensor: """ Performs the `sin` operation on the tensor. """ pass + def sqr(self) -> Tensor: """ Squares the tensor. """ pass + def sqrt(self) -> Tensor: """ Calculates the square root of the tensor. """ pass + def squeeze(self, dim: int) -> Tensor: """ Creates a new tensor with the specified dimension removed if its size was one. """ pass + @property def stride(self) -> Tuple[int]: """ Gets the tensor's strides. """ pass + def sum_all(self) -> Tensor: """ Returns the sum of the tensor. """ pass + def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor: """ Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions. """ pass + def t(self) -> Tensor: """ Transposes the tensor. """ pass + def to(self, *args, **kwargs) -> Tensor: """ Performs Tensor dtype and/or device conversion. """ pass + def to_device(self, device: Union[str, Device]) -> Tensor: """ Move the tensor to a new device. """ pass + def to_dtype(self, dtype: Union[str, DType]) -> Tensor: """ Convert the tensor to a new dtype. """ pass + def to_torch(self) -> torch.Tensor: """ Converts candle's tensor to pytorch's tensor """ pass + def transpose(self, dim1: int, dim2: int) -> Tensor: """ Returns a tensor that is a transposed version of the input, the given dimensions are swapped. """ pass + def unsqueeze(self, dim: int) -> Tensor: """ Creates a new tensor with a dimension of size one inserted at the specified position. """ pass + def values(self) -> _ArrayLike: """ Gets the tensor's data as a Python scalar or array-like object. """ pass + def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor: """ Returns a tensor with the same shape as the input tensor, the values are taken from diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py index 6ece31b6db..963a8a4a6e 100644 --- a/candle-pyo3/py_src/candle/nn/container.py +++ b/candle-pyo3/py_src/candle/nn/container.py @@ -57,12 +57,10 @@ class Sequential(Module): _modules: Dict[str, Module] # type: ignore[assignment] @overload - def __init__(self, *args: Module) -> None: - ... + def __init__(self, *args: Module) -> None: ... @overload - def __init__(self, arg: "OrderedDict[str, Module]") -> None: - ... + def __init__(self, arg: "OrderedDict[str, Module]") -> None: ... def __init__(self, *args): super().__init__() diff --git a/candle-pyo3/py_src/candle/nn/module.py b/candle-pyo3/py_src/candle/nn/module.py index 514d92b86e..972d9a91df 100644 --- a/candle-pyo3/py_src/candle/nn/module.py +++ b/candle-pyo3/py_src/candle/nn/module.py @@ -204,12 +204,10 @@ def named_buffers( T_destination = TypeVar("T_destination", bound=Dict[str, Any]) @overload - def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: - ... + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: - ... + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) 
-> Dict[str, Any]: ... def state_dict(self, *args, destination=None, prefix="", keep_vars=False): r"""Returns a dictionary containing references to the whole state of the module. @@ -586,12 +584,10 @@ def to( self: T, device: str = ..., dtype: Optional[Union[DType, str]] = ..., - ) -> T: - ... + ) -> T: ... @overload - def to(self: T, dtype: Union[DType, str]) -> T: - ... + def to(self: T, dtype: Union[DType, str]) -> T: ... def to(self, *args, **kwargs): r"""Moves and/or casts the parameters and buffers. diff --git a/candle-pyo3/py_src/candle/nn/normalization.py b/candle-pyo3/py_src/candle/nn/normalization.py index 67510a24bb..61d29c51f0 100644 --- a/candle-pyo3/py_src/candle/nn/normalization.py +++ b/candle-pyo3/py_src/candle/nn/normalization.py @@ -14,6 +14,7 @@ class LayerNorm(Module): math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta """ + __constants__ = ["normalized_shape", "eps"] normalized_shape: Tuple[int, ...] eps: float diff --git a/candle-pyo3/py_src/candle/onnx/__init__.pyi b/candle-pyo3/py_src/candle/onnx/__init__.pyi index 8ce1b3aaca..a23cd2f0d5 100644 --- a/candle-pyo3/py_src/candle/onnx/__init__.pyi +++ b/candle-pyo3/py_src/candle/onnx/__init__.pyi @@ -11,59 +11,69 @@ class ONNXModel: def __init__(self, path: str): pass + @property def doc_string(self) -> str: """ The doc string of the model. """ pass + @property def domain(self) -> str: """ The domain of the operator set of the model. """ pass + def initializers(self) -> Dict[str, Tensor]: """ Get the weights of the model. """ pass + @property def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: """ The inputs of the model. """ pass + @property def ir_version(self) -> int: """ The version of the IR this model targets. """ pass + @property def model_version(self) -> int: """ The version of the model. """ pass + @property def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: """ The outputs of the model. """ pass + @property def producer_name(self) -> str: """ The producer of the model. """ pass + @property def producer_version(self) -> str: """ The version of the producer of the model. """ pass + def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: """ Run the model on the given inputs. @@ -81,6 +91,7 @@ class ONNXTensorDescription: The data type of the tensor. """ pass + @property def shape(self) -> Tuple[Union[int, str, Any]]: """ diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ca40687607..7b9a741340 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -938,8 +938,8 @@ impl PyTensor { /// Detach the tensor from the computation graph. /// &RETURNS&: Tensor - fn detach(&self) -> PyResult { - Ok(PyTensor(self.0.detach().map_err(wrap_err)?)) + fn detach(&self) -> Self { + PyTensor(self.0.detach()) } /// Returns a copy of the tensor. 
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index c459ebb39e..165941bd79 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -189,7 +189,6 @@ def do_black(content, is_pyi):
         line_length=119,
         is_pyi=is_pyi,
         string_normalization=True,
-        experimental_string_processing=False,
     )
     try:
         return black.format_file_contents(content, fast=True, mode=mode)
diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md
index 04ff203392..f8d8ad9de6 100644
--- a/candle-wasm-examples/segment-anything/README.md
+++ b/candle-wasm-examples/segment-anything/README.md
@@ -1,6 +1,7 @@
 ## Running Segment Anything Example
 
-Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
+Here, we provide an example showing how to run the Segment Anything model in the
+browser.
 
 ### Vanilla JS and WebWorkers
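For downstream users of candle, the net effect of the `detach` change above is that call sites drop a `?`, since `Tensor::detach` now returns a `Tensor` directly instead of a `Result<Tensor>`. The following is a minimal illustrative sketch (not part of the patch) of what that looks like, assuming the `candle_core` API as modified above:

```rust
use candle_core::{DType, Device, Result, Var};

fn main() -> Result<()> {
    // Creating a variable can still fail (e.g. on allocation), so `?` stays here.
    let v = Var::zeros((2, 3), DType::F32, &Device::Cpu)?;

    // Previously: let t = v.as_tensor().detach()?;
    // Now detach() is infallible and returns a Tensor directly.
    let t = v.as_tensor().detach();

    // The new helper added in candle-core/src/variable.rs does the same in one call,
    // which is what batch_norm.rs uses above for the running statistics.
    let t2 = v.as_detached_tensor();

    assert_eq!(t.shape(), t2.shape());
    Ok(())
}
```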