Source code for ltfatpy.sigproc.groupthresh

# -*- coding: utf-8 -*-
# ######### COPYRIGHT #########
# Credits
# #######
#
# Copyright(c) 2015-2018
# ----------------------
#
# * `LabEx Archimède <http://labex-archimede.univ-amu.fr/>`_
# * `Laboratoire d'Informatique Fondamentale <http://www.lif.univ-mrs.fr/>`_
#   (now `Laboratoire d'Informatique et Systèmes <http://www.lis-lab.fr/>`_)
# * `Institut de Mathématiques de Marseille <http://www.i2m.univ-amu.fr/>`_
# * `Université d'Aix-Marseille <http://www.univ-amu.fr/>`_
#
# This software is a port from LTFAT 2.1.0 :
# Copyright (C) 2005-2018 Peter L. Soendergaard <peter@sonderport.dk>.
#
# Contributors
# ------------
#
# * Denis Arrivault <contact.dev_AT_lis-lab.fr>
# * Florent Jaillet <contact.dev_AT_lis-lab.fr>
#
# Description
# -----------
#
# ltfatpy is a partial Python port of the
# `Large Time/Frequency Analysis Toolbox <http://ltfat.sourceforge.net/>`_,
# a MATLAB®/Octave toolbox for working with time-frequency analysis and
# synthesis.
#
# Version
# -------
#
# * ltfatpy version = 1.0.16
# * LTFAT version = 2.1.0
#
# Licence
# -------
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# ######### COPYRIGHT #########


""" Module of group thresholding

Ported from ltfat_2.1.0/sigproc/groupthresh.m

.. moduleauthor:: Florent Jaillet
"""

from __future__ import print_function, division

import numpy as np

from ltfatpy.comp.assert_sigreshape_pre import assert_sigreshape_pre
from ltfatpy.comp.assert_sigreshape_post import assert_sigreshape_post
from ltfatpy.sigproc.thresh import thresh


[docs]def groupthresh(xi, lamb, dim=1, group_type='group', thresh_type='soft'): """Group thresholding - Usage: | ``xo = groupthresh(xi, lamb)`` | ``xo = groupthresh(xi, lamb, dim)`` | ``xo = groupthresh(xi, lamb, ...)`` - Input parameters: :param numpy.ndarray xi: Input array. ``xi`` must be a two-dimensional array, with dimension ``0`` labelling groups, and dimension ``1`` labelling members. This means that the groups are the row vectors of the input (the vectors along the dimension ``1``). :param float lamb: Threshold :param int dim: Dimension along which to choose the groups (default ``dim=1``) :param str group_type: Optional flag specifying the grouping behaviour (see possible values below) :param str thresh_type: Optional flag specifying the type of thresholding within each group (see the parameter ``thresh_type`` in the help of :func:`~ltfatpy.sigproc.thresh.thresh` for possible values, default ``thresh_type=soft``) - Output parameters: :returns: Array of the same shape as **xi** containing data from **xi** after group thresholding :rtype: numpy.ndarray ``groupthresh(xi, lamb)`` performs group thresholding on ``xi``, with threshold ``lamb``. ``groupthresh(xi, lamb, dim)`` chooses groups along dimension ``dim``. The parameter **group_type** can take the following values: =========== ======================================================= ``'group'`` Shrinks all coefficients within a given group according to the value of the :math:`l^2` norm of the group in comparison to the threshold ``lamb``. This is the default ``'elite'`` Shrinks all coefficients within a given group according to the value of the :math:`l^1` norm of the group in comparison to the threshold value ``lamb`` =========== ======================================================= .. seealso:: :func:`~ltfatpy.sigproc.thresh.thresh` - References: :cite:`Kowalski08sparsity,kowalski2009mixed,yu2008audio` """ # Note: This function doesn't support the handling of sparse matrices # available in the Octave version. Only full numpy arrays are supported in # input and output. if not isinstance(lamb, float): raise TypeError('lamb must be a float') # dim (the time or frequency selector) is handled by assert_sigreshape_pre xi, L, NbMembers, NbGroups, dim, permutedshape, order = \ assert_sigreshape_pre(xi, None, dim) # Dense case (this Python port doesn't handle the sparse matrix case) xo = np.zeros(xi.shape, dtype=xi.dtype) if group_type == 'group': groupnorm = np.sqrt(np.sum(np.abs(xi)**2., axis=0)) w = thresh(groupnorm, lamb, thresh_type=thresh_type)[0] / groupnorm # Clean w for NaN. NaN appears if the input has a group with norm # exactly 0. w[np.isnan(w)] = 0 xo = xi * w elif group_type == 'elite': for ii in range(NbGroups): y = np.sort(np.abs(xi[:, ii]))[::-1] rhs = np.cumsum(y) rhs = rhs * lamb / (1. + lamb*np.arange(1., NbMembers+1)) M_ii = np.nonzero(np.diff(np.sign(y-rhs)))[0] # Note: the test on M_ii in the Octave version of this function # is surprisingly written as find can only return non-zero values # or an empty array. Here an equivalent test is written using a # more explicit formulation. if M_ii.size != 0: tau_ii = float(lamb * np.linalg.norm(y[:M_ii[0]+1], 1) / (1. + lamb*(M_ii[0]+1))) else: tau_ii = 0. xo[:, ii] = thresh(xi[:, ii], tau_ii, thresh_type=thresh_type)[0] xo = assert_sigreshape_post(xo, dim, permutedshape, order) return xo