Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a semi-supervised (specifically a combination of supervised and weakly-supervised data) version of weak algorithms #268

Closed
wants to merge 9 commits into from
Closed
5 changes: 3 additions & 2 deletions metric_learn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .sdml import SDML, SDML_Supervised
from .nca import NCA
from .lfda import LFDA
from .rca import RCA, RCA_Supervised
from .rca import RCA, RCA_Supervised, RCA_SemiSupervised
from .mlkr import MLKR
from .mmc import MMC, MMC_Supervised

Expand All @@ -17,4 +17,5 @@
__all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised',
'LMNN', 'LSML', 'LSML_Supervised', 'SDML',
'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised',
'MLKR', 'MMC', 'MMC_Supervised', '__version__']
'RCA_SemiSupervised', 'MLKR', 'MMC', 'MMC_Supervised',
'__version__']
90 changes: 89 additions & 1 deletion metric_learn/rca.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def fit(self, X, chunks):

Parameters
----------
data : (n x d) data matrix
X : (n x d) data matrix
Each row corresponds to a single instance
chunks : (n,) array of ints
When ``chunks[i] == -1``, point i doesn't belong to any chunklet.
Expand Down Expand Up @@ -242,3 +242,91 @@ def fit(self, X, y, random_state='deprecated'):
chunk_size=self.chunk_size,
random_state=self.random_state)
return RCA.fit(self, X, chunks)


class RCA_SemiSupervised(RCA):
"""Semi-Supervised version of Relevant Components Analysis (RCA)

`RCA_SemiSupervised` combines data in the form of chunks with
data in the form of labeled points that goes through the same
process as in `RCA_SemiSupervised`.

Parameters
----------
n_components : int or None, optional (default=None)
Dimensionality of reduced space (if None, defaults to dimension of X).

num_dims : Not used

.. deprecated:: 0.5.0
`num_dims` was deprecated in version 0.5.0 and will
be removed in 0.6.0. Use `n_components` instead.

num_chunks: int, optional

chunk_size: int, optional

preprocessor : array-like, shape=(n_samples, n_features) or callable
The preprocessor to call to get tuples from indices. If array-like,
tuples will be formed like this: X[indices].

random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int.
It is used to randomly sample constraints from labels.

Attributes
----------
components_ : `numpy.ndarray`, shape=(n_components, n_features)
The learned linear transformation ``L``.
"""

def __init__(self, num_dims='deprecated', n_components=None,
pca_comps='deprecated', num_chunks=100, chunk_size=2,
preprocessor=None, random_state=None):
"""Initialize the supervised version of `RCA`."""
RCA.__init__(self, num_dims=num_dims, n_components=n_components,
pca_comps=pca_comps, preprocessor=preprocessor)
self.num_chunks = num_chunks
self.chunk_size = chunk_size
self.random_state = random_state

def fit(self, X, y, X_u, chunks,
random_state='deprecated'):
"""Create constraints from labels and learn the RCA model.
Needs num_constraints specified in constructor.

Parameters
----------
X : (n x d) labeled data matrix
each row corresponds to a single instance
y : (n) data labels
X_u : (n x d) unlabeled data matrix
chunks : (n,) array of ints
When ``chunks[i] == -1``, point i doesn't belong to any chunklet.
When ``chunks[i] == j``, point i belongs to chunklet j.
random_state : Not used
.. deprecated:: 0.5.0
`random_state` in the `fit` function was deprecated in version 0.5.0
and will be removed in 0.6.0. Set `random_state` at initialization
instead (when instantiating a new `RCA_SemiSupervised` object).
"""
if random_state != 'deprecated':
warnings.warn('"random_state" parameter in the `fit` function is '
'deprecated. Set `random_state` at initialization '
'instead (when instantiating a new `RCA_SemiSupervised` '
'object).', DeprecationWarning)
else:
warnings.warn('As of v0.5.0, `RCA_SemiSupervised` now uses the '
'`random_state` given at initialization to sample '
'constraints, not the default `np.random` from the `fit` '
'method, since this argument is now deprecated. '
'This warning will disappear in v0.6.0.',
ChangedBehaviorWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
sup_chunks = Constraints(y).chunks(num_chunks=self.num_chunks,
chunk_size=self.chunk_size,
random_state=self.random_state)
X_tot = np.concatenate([X, X_u])
chunks_tot = np.concatenate([sup_chunks, chunks])

return RCA.fit(self, X_tot, chunks_tot)
15 changes: 13 additions & 2 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
LSML_Supervised, ITML_Supervised, SDML_Supervised,
RCA_Supervised, MMC_Supervised, SDML, RCA, ITML,
LSML)
LSML, RCA_SemiSupervised)
# Import this specially for testing.
from metric_learn.constraints import wrap_pairs
from metric_learn.constraints import wrap_pairs, Constraints
from metric_learn.lmnn import _sum_outer_products


Expand Down Expand Up @@ -1136,6 +1136,17 @@ def test_changed_behaviour_warning_random_state(self):
rca_supervised.fit(X, y)
assert any(msg == str(wrn.message) for wrn in raised_warning)

def test_semi_supervised(self):
n = 100
X, y = make_classification(random_state=42, n_samples=2 * n)
rca_semisupervised = RCA_SemiSupervised(num_chunks=20)
cons = Constraints(y[n:])
chunks = cons.chunks(num_chunks=20)
rca_semisupervised.fit(X[:n], y[:n],
X[n:], chunks)
rca_semisupervised.fit(X[:n], y[:n],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably add more tests around what rca_semisupervised looks like after fitting

X[n:], chunks, random_state=42)


@pytest.mark.parametrize('num_dims', [None, 2])
def test_deprecation_num_dims_rca(num_dims):
Expand Down