# nws_tools.py - Collection of network creation/processing/analysis/plotting routines
#
# Author: Stefan Fuertinger [stefan.fuertinger@esi-frankfurt.de]
# Created: December 22 2014
# Last modified: <2017-11-06 12:49:28>
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
import natsort
import os
import csv
import inspect
import fnmatch
try:
from scipy import weave
except:
import weave
from numpy.linalg import norm
from mpl_toolkits.mplot3d import Axes3D, proj3d
from matplotlib.patches import FancyArrowPatch, Circle
from matplotlib.colors import Normalize, colorConverter, LightSource
import math
from recipes import myglob
##########################################################################################
def strengths_und(CIJ):
"""
Compute nodal strengths in an undirected graph
Parameters
----------
CIJ : NumPy 2darray
Undirected binary/weighted connection matrix
Returns
-------
st : NumPy 1darray
Nodal strength vector
Notes
-----
This function does *not* do any error checking and assumes you know what you are doing
See also
--------
strengths_und.m : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
bctpy : An unofficial Python port of the BCT is currently available at the
`Python Package Index <https://pypi.python.org/pypi/bctpy>`_
and can be installed using `pip`.
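Examples
--------
A minimal sketch, using a hypothetical 3-node weighted toy matrix (values
chosen purely for illustration):
>>> CIJ = np.array([[0., 1., 2.],
...                 [1., 0., 0.],
...                 [2., 0., 0.]])
>>> strengths_und(CIJ)
array([ 3.,  1.,  2.])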
"""
return np.sum(CIJ,axis=0)
##########################################################################################
def degrees_und(CIJ):
"""
Compute nodal degrees in an undirected graph
Parameters
----------
CIJ : NumPy 2darray
Undirected binary/weighted connection matrix
Returns
-------
deg : NumPy 1darray
Nodal degree vector
Notes
-----
This function does *not* do any error checking and assumes you know what you are doing
See also
--------
degrees_und.m : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
bctpy : An unofficial Python port of the BCT is currently available at the
`Python Package Index <https://pypi.python.org/pypi/bctpy>`_
and can be installed using `pip`.
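Examples
--------
A minimal sketch, using a hypothetical 3-node weighted toy matrix: nonzero
entries are counted per node, irrespective of their weights:
>>> CIJ = np.array([[0., 1., 2.],
...                 [1., 0., 0.],
...                 [2., 0., 0.]])
>>> degrees_und(CIJ)
array([2, 1, 1])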
"""
return (CIJ != 0).sum(1)
##########################################################################################
def density_und(CIJ):
"""
Compute the connection density of an undirected graph
Parameters
----------
CIJ : NumPy 2darray
Undirected binary/weighted connection matrix
Returns
-------
den : float
density (fraction of present connections to possible connections)
Notes
-----
This function does *not* do any error checking and assumes you know what you are doing
See also
--------
density_und.m : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
bctpy : An unofficial Python port of the BCT is currently available at the
`Python Package Index <https://pypi.python.org/pypi/bctpy>`_
and can be installed using `pip`.
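Examples
--------
A minimal sketch, using a hypothetical 3-node toy matrix holding 2 of its 3
possible undirected edges, i.e., a density of 2/3:
>>> CIJ = np.array([[0., 1., 2.],
...                 [1., 0., 0.],
...                 [2., 0., 0.]])
>>> round(density_und(CIJ),4)
0.6667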
"""
N = CIJ.shape[0] # no. of nodes
K = (np.triu(CIJ,1)!=0).sum() # no. of edges
return K/((N**2 - N)/2.0)
##########################################################################################
def get_corr(txtpath,corrtype='pearson',sublist=[],**kwargs):
"""
Compute pair-wise statistical dependence of time-series
Parameters
----------
txtpath : str
Path to directory holding ROI-averaged time-series dumped in `txt` files.
The following file-naming convention is required `sNxy_bla_bla.txt`,
where `N` is the group id (1,2,3,...), `xy` denotes the subject number
(01,02,...,99 or 001,002,...,999) and everything else is separated
by underscores. The files will be read in lexicographic order,
i.e., `s101_1.txt`, `s101_2.txt`,... or `s101_Amygdala.txt`, `s101_Beemygdala.txt`,...
See Notes for more details.
corrtype : str
Specifier indicating which type of statistical dependence to use to compute
pairwise dependence. Currently supported options are
`pearson`: the classical zero-lag Pearson correlation coefficient
(see NumPy's `corrcoef` for details)
`mi`: (normalized) mutual information
(see the docstring of `mutual_info` in this module for details)
sublist : list or NumPy 1darray
List of subject codes to process, e.g., `sublist = ['s101','s102']`.
By default all subjects found in `txtpath` will be processed.
**kwargs : keyword arguments
Additional keyword arguments to be passed on to the function computing
the pairwise dependence (currently either NumPy's `corrcoef` or `mutual_info`
in this module).
Returns
-------
res : dict
Dictionary with fields:
corrs : NumPy 3darray
`N`-by-`N` matrices of pair-wise regional statistical dependencies
of `numsubs` subjects. Format is `corrs.shape = (N,N,numsubs)` such that
`corrs[:,:,i]` = `N x N` statistical dependence matrix of `i`-th subject
bigmat : NumPy 3darray
Tensor holding unprocessed time series of all subjects. Format is
`bigmat.shape = (tlen,N,numsubs)` where `tlen` is the maximum
time-series-length across all subjects (if time-series of different
lengths were used in the computation, any unfilled entries in `bigmat`
will be NumPy `nan`'s, see Notes for details) and `N` is the number of
regions (=nodes in the networks).
sublist : list of strings
List of processed subjects specified by `txtpath`, e.g.,
`sublist = ['s101','s103','s110','s111','s112',...]`
Notes
-----
Per-subject time-series do not necessarily have to be of the same length across
a subject cohort. However, all ROI-time-courses *within* the same subject must have
the same number of entries.
For instance, all ROI-time-courses in `s101` can have 140 entries, and time-series
of `s102` might have 130 entries. The remaining 10 values "missing" for `s102` are
filled with `NaN`'s in `bigmat`. However, if `s101_2.txt` contains 140 data-points while only
130 entries are found in `s101_3.txt`, the code will raise a `ValueError`.
See also
--------
corrcoef : Pearson product-moment correlation coefficients computed in NumPy
mutual_info : Compute (normalized) mutual information coefficients
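Examples
--------
A usage sketch, assuming the hypothetical directory `~/timeseries` holds files
`s101_1.txt`, `s101_2.txt`, ..., `s102_1.txt`, ... following the naming
convention described above:
>>> res = get_corr('~/timeseries',corrtype='pearson')
>>> res['corrs'][:,:,0]             # Pearson matrix of the first subject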
"""
# Make sure `txtpath` doesn't contain nonsense and points to an existing location
if not isinstance(txtpath,(str,unicode)):
raise TypeError('Input has to be a string specifying the path to the txt-file directory!')
txtpath = str(txtpath)
if txtpath.find("~") == 0:
txtpath = os.path.expanduser('~') + txtpath[1:]
if not os.path.isdir(txtpath):
raise ValueError('Invalid directory: '+txtpath+'!')
# Check `corrtype`
if not isinstance(corrtype,(str,unicode)):
raise TypeError('Statistical dependence type input must be a string, not '+type(corrtype).__name__+'!')
if corrtype != 'mi' and corrtype != 'pearson':
raise ValueError("Currently, only Pearson and (N)MI supported!")
# Check `sublist`
if not isinstance(sublist,(list,np.ndarray)):
raise TypeError('Subject codes have to be provided as Python list/NumPy 1darray, not '+type(sublist).__name__+'!')
if len(np.array(sublist).shape) != 1:
raise ValueError("Subject codes have to be provided as 1-d list/array!")
# Get length of `sublist` (to see if a subject list was provided)
numsubs = len(sublist)
# Get list of all txt-files in `txtpath` and order them lexicographically
if txtpath[-1] == ' ' or txtpath[-1] == os.sep: txtpath = txtpath[:-1]
txtfiles = natsort.natsorted(myglob(txtpath,"s*.[Tt][Xx][Tt]"), key=lambda y: y.lower())
if len(txtfiles) < 2: raise ValueError('Found fewer than 2 text files in '+txtpath+'!')
# If no subject-list was provided, take first subject to get the number of ROIs to be processed
if numsubs == 0:
# Search from left in file-name for first "s" (naming scheme: sNxy_bla_bla_.txt)
firstsub = txtfiles[0]
firstsub = firstsub.replace(txtpath+os.sep,'')
s_in_name = firstsub.find('s')
# The characters right of "s" until the first "_" are the subject identifier
udrline = firstsub[s_in_name::].find('_')
subject = firstsub[s_in_name:s_in_name+udrline]
# Generate list of subjects
sublist = [subject]
for fl in txtfiles:
if fl.count(subject) == 0:
s_in_name = fl.rfind('s')
udrline = fl[s_in_name::].find('_')
subject = fl[s_in_name:s_in_name+udrline]
sublist.append(subject)
# Update `numsubs`
numsubs = len(sublist)
# Prepare output message
msg = "Found "
else:
# Just take the first entry of user-provided subject list
subject = sublist[0]
# Prepare output message
msg = "Processing "
# Talk to the user
print msg+str(numsubs)+" subjects: "+"".join(sb+", " for sb in sublist)[:-2]
# Check if the number of ROIs is consistent across subjects
nrois = np.zeros((numsubs,),dtype=int)
txtflstr = ''.join(txtfiles)
for ns, sub in enumerate(sublist):
nrois[ns] = txtflstr.count(sub+"_")
nroisu = np.unique(nrois).astype(int)
if nroisu.size > 1:
if nroisu.min() == 0:
bad_subs = ""
else:
bad_subs = "Found "
for nsu in nroisu:
if nsu == 0:
bad_subs += "No data found for Subject(s) "
else:
bad_subs += str(nsu)+" regions in Subject(s) "
bad_subs += "".join(sublist[idx]+", " for idx in np.where(nrois == nsu)[0])
msg = "Inconsisten number of time-series across subjects! "+bad_subs[:-2]
raise ValueError(msg)
else:
numregs = nroisu[0]
# Get (actual) number of subjects
numsubs = len(sublist)
# Scan files to find time-series length
tlens = np.zeros((numsubs,),dtype=int)
for k in xrange(numsubs):
roi = 0
for fl in txtfiles:
if fl.count(sublist[k]+"_"): # make sure we differentiate b/w "s1_*.txt" and "s10_.txt"...
try:
ts_vec = np.loadtxt(fl)
except:
raise ValueError("Cannot read file "+fl)
if roi == 0:
tlens[k] = ts_vec.size # Subject's first TS sets our reference length
if ts_vec.size != tlens[k]:
raise ValueError("Error reading file: "+fl+\
" Expected a time-series of length "+str(tlens[k])+", "+
"but actual length is "+str(ts_vec.size))
roi += 1
# Check the lengths of the detected time-series
if tlens.min() <= 2:
raise ValueError('Time-series of Subject '+sublist[tlens.argmin()]+' is empty or has fewer than 3 entries!')
# Allocate tensor to hold all time series
bigmat = np.zeros((tlens.max(),numregs,numsubs)) + np.nan
# Allocate tensor holding statistical dependence matrices of all subjects
corrs = np.zeros((numregs,numregs,numsubs))
# Ready to do this...
print "Extracting data and calculating "+corrtype.upper()+" coefficients"
# Cycle through subjects and save per-subject time series data column-wise
for k in xrange(numsubs):
col = 0
for fl in txtfiles:
if fl.count(sublist[k]+"_"):
ts_vec = np.loadtxt(fl)
bigmat[:tlens[k],col,k] = ts_vec
col += 1
# Compute statistical dependence based on corrtype
if corrtype == 'pearson':
corrs[:,:,k] = np.corrcoef(bigmat[:tlens[k],:,k],rowvar=0,**kwargs)
elif corrtype == 'mi':
corrs[:,:,k] = mutual_info(bigmat[:tlens[k],:,k],**kwargs)
# Happy breakdown
print "Done"
return {'corrs':corrs, 'bigmat':bigmat, 'sublist':sublist}
##########################################################################################
def corrcheck(*args,**kwargs):
"""
Sanity checks for statistical dependence matrices
Parameters
----------
Dynamic : Usage as follows
corrcheck(A) : input is NumPy 2darray
shows some statistics for the statistical dependence matrix `A`
corrcheck(A,label) : input is NumPy 2darray and `['string']`
shows some statistics for the matrix `A` and uses
`label`, a list containing one string, as title in figures.
corrcheck(A,B,C,...) : input are many NumPy 2darrays
shows some statistics for the statistical dependence matrices `A`, `B`, `C`,....
corrcheck(A,B,C,...,label) : input are many NumPy 2darrays and a list of strings
shows some statistics for the statistical dependence matrices `A`, `B`, `C`,....
and uses the list of strings `label` to generate titles in figures.
Note that `len(label)` has to be equal to the number of
input matrices.
corrcheck(T) : input is NumPy 3darray
shows some statistics for statistical dependence matrices stored
in the tensor `T`. The storage scheme has to be
`T[:,:,0] = A`
`T[:,:,1] = B`
`T[:,:,2] = C`
etc.
where `A`, `B`, `C`,... are matrices.
corrcheck(T,label) : input is NumPy 3darray and list of strings
shows some statistics for matrices stored
in the tensor `T`. The storage scheme has to be
`T[:,:,0] = A`
`T[:,:,1] = B`
`T[:,:,2] = C`
etc.
where `A`, `B`, `C`,... are matrices. The list of strings `label`
is used to generate titles in figures. Note that `len(label)`
has to be equal to `T.shape[2]`
corrcheck(...,title='mytitle') : input is any of the above
same as above and uses the string `mytitle` as window name for figures.
Returns
-------
Nothing : None
Notes
-----
None
See also
--------
None
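Examples
--------
A usage sketch, assuming `A` and `B` are same-sized square NumPy 2darrays:
>>> corrcheck(A,B,['Subject 1','Subject 2'],title='Group QA')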
"""
# Plotting params used later (max. #plots per row)
cplot = 5
# Sanity checks
myin = len(args)
if myin == 0:
raise ValueError('At least one input required!')
# Assign global name for all figures if provided by additional keyword argument `title`
figtitle = kwargs.get('title',None)
nofigname = False
if figtitle is None:
nofigname = True
else:
if not isinstance(figtitle,(str,unicode)):
raise ValueError('Figure title must be a string!')
# If labels have been provided, extract them now
if isinstance(args[-1],(list)):
myin -= 1
labels = args[-1]
usrlbl = 1
elif isinstance(args[-1],(str,unicode)):
myin -= 1
labels = [args[-1]]
usrlbl = 1
else:
usrlbl = 0
# Try to get shape of input
if not isinstance(args[0],np.ndarray):
raise TypeError("Expected NumPy array(s) as input, found "+type(args[0]).__name__+"!")
szin = len(args[0].shape)
# If input is a list of matrices, store them in a tensor
if szin == 2:
rw,cl = args[0].shape
if (rw != cl) or (min(args[0].shape)==1):
raise ValueError('Input matrices must be square!')
corrs = np.zeros((rw,cl,myin))
for i in xrange(myin):
if not isinstance(args[i],np.ndarray):
raise TypeError("All but last input must be NumPy arrays!")
try:
corrs[:,:,i] = args[i]
except:
raise ValueError('All input matrices must be real and of the same size!')
# If input is a tensor, there's not much to do
elif szin == 3:
if myin > 1: raise ValueError('Not more than one input tensor supported!')
shv = args[0].shape
if (min(shv[0],shv[1]) == 1) or (shv[0]!=shv[1]):
raise ValueError('Input tensor must be of the format N-by-N-by-k!')
corrs = args[0]
else:
raise TypeError('Input has to be either a matrix/matrices or a rank-3 tensor!')
# Count number of matrices and get their dimension
nmat = corrs.shape[-1]
N = corrs.shape[0]
# Check if those matrices are real and "reasonable"
if not np.issubdtype(corrs.dtype, np.number) or not np.isreal(corrs).all():
raise ValueError("Input arrays must be real-valued!")
if np.isfinite(corrs).min() == False:
raise ValueError("All matrices must be real without NaNs or Infs!")
# Check if we're dealing with Pearson or NMI matrices (or something completely unexpected)
cmin = corrs.min(); cmax = corrs.max()
if cmax > 1 or cmin < -1:
msg = "WARNING: Input has to have values between -1/+1 or 0/+1. Found "+str(cmin)+" to "+str(cmax)
print msg
maxval = 1
if corrs.min() < 0:
minval = -1
else:
minval = 0
# If labels have been provided, check if we got enough of'em; if there are no labels, generate defaults
if (usrlbl):
if len(labels) != nmat:
raise ValueError('Numbers of labels and matrices do not match up!')
for lb in labels:
if not isinstance(lb,(str,unicode)):
raise ValueError('Labels must be provided as list of strings or a single string!')
else:
labels = ['Matrix '+str(i+1) for i in xrange(nmat)]
# Set subplot params and turn on interactive plotting
rplot = int(np.ceil(nmat/cplot))
if nmat <= cplot: cplot = nmat
plt.ion()
# Now let's actually do something and plot the statistical dependence matrices (show a warning if a matrix is not symmetric)
fig = plt.figure(figsize=(8,8))
if nofigname: figtitle = fig.canvas.get_window_title()
fig.canvas.set_window_title(figtitle+': '+str(N)+' Nodes')
for i in xrange(nmat):
plt.subplot(rplot,cplot,i+1)
im = plt.imshow(corrs[:,:,i],cmap='jet',interpolation='nearest',vmin=minval,vmax=maxval)
plt.axis('off')
plt.title(labels[i])
if issym(corrs[:,:,i]) == False:
print "WARNING: "+labels[i]+" is not symmetric!"
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
plt.draw()
# Plot statistical dependence histograms
meanval = np.mean([minval,maxval])
idx = np.nonzero(np.triu(np.ones((N,N)),1))
NN = (N**2 - N)/2
fig = plt.figure(figsize=(8,8))
if nofigname: figtitle = fig.canvas.get_window_title()
fig.canvas.set_window_title(figtitle+': '+"Statistical Dependence Histograms")
bars = []; ylims = []
for i in xrange(nmat):
cvec = corrs[idx[0],idx[1],i]
[corrcount,corrbins] = np.histogram(cvec,bins=20,range=(minval,maxval))
bars.append(plt.subplot(rplot,cplot,i+1))
plt.bar(corrbins[:-1],corrcount/NN,width=np.abs(corrbins[0]-corrbins[1]))
ylims.append(bars[-1].get_ylim()[1])
plt.xlim(minval,maxval)
plt.xticks((minval,meanval,maxval),(str(minval),str(meanval),str(maxval)))
plt.title(labels[i])
if np.mod(i+1,cplot) == 1: plt.ylabel('Frequency')
ymax = max(ylims)
for mybar in bars: mybar.set_ylim(top=ymax)
plt.draw()
# Show negative correlations (for Pearson matrices)
if minval < 0:
fig = plt.figure(figsize=(8,8))
if nofigname: figtitle = fig.canvas.get_window_title()
fig.canvas.set_window_title(figtitle+': '+"Negative Correlations Are BLACK")
for i in xrange(nmat):
plt.subplot(rplot,cplot,i+1)
plt.imshow((corrs[:,:,i]>=0).astype(float),cmap='gray',interpolation='nearest',vmin=0,vmax=1)
plt.axis('off')
plt.title(labels[i])
plt.draw()
# Diversity
fig = plt.figure(figsize=(8,8))
if nofigname: figtitle = fig.canvas.get_window_title()
fig.canvas.set_window_title(figtitle+': '+"Diversity of Statistical Dependencies")
xsteps = np.arange(1,N+1)
stems = []; ylims = []
for i in xrange(nmat):
stems.append(plt.subplot(rplot,cplot,i+1))
varc = np.var(corrs[:,:,i],0,ddof=1)
plt.stem(xsteps,varc)
ylims.append(stems[-1].get_ylim()[1])
plt.xlim(-1,N+1)
plt.xticks((0,N),('1',str(N)))
plt.title(labels[i])
ymax = max(ylims)
for mystem in stems: mystem.set_ylim(top=ymax)
plt.draw()
##########################################################################################
def get_meannw(nws,percval=0.0):
"""
Helper function to compute group-averaged networks
Parameters
----------
nws : NumPy 3darray
`N`-by-`N` connection matrices of `numsubs` subjects. Format is `nws.shape = (N,N,numsubs)`
such that `nws[:,:,i] = N x N` connection matrix of `i`-th subject
percval : float
Percentage value, such that connections not present in at least `percval`
percent of subjects are not considered, thus `0 <= percval <= 1`.
Default setting is `percval = 0.0`
Returns
-------
mean_wghted : NumPy 2darray
`N`-by-`N` mean value matrix of `numsubs` matrices stored in `nws` where
only connections present in at least `percval` percent of subjects
are considered
percval : float
Percentage value used to generate `mean_wghted`
Notes
-----
If the current setting of `percval` leads to a disconnected network,
the code decreases `percval` in 5% steps to ensure connectedness of the group-averaged graph.
The concept of using only a certain percentage of edges present in subjects was taken from [1]_.
See also
--------
None
References
----------
.. [1] M. van den Heuvel, O. Sporns. Rich-Club Organization of the Human Connectome.
J. Neurosci, 31(44) 15775-15786, 2011.
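Examples
--------
A minimal sketch, using a hypothetical stack of 5 random symmetric networks:
>>> tmp = np.random.rand(10,10,5)
>>> nws = (tmp + tmp.transpose(1,0,2))/2
>>> mean_wghted, percval = get_meannw(nws,percval=0.5)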
"""
# Sanity checks
arrcheck(nws,'tensor','nws')
scalarcheck(percval,'percval',bounds=[0,1])
# Get shape of input tensor
N = nws.shape[0]
numsubs = nws.shape[-1]
# Remove self-connections
nws = rm_selfies(nws)
# Allocate memory for binary/weighted group averaged networks
mean_binary = np.zeros((N,N))
mean_wghted = np.zeros((N,N))
# Compute mean network and keep increasing `percval` until we get a connected mean network
docalc = True
while docalc:
# Reset matrices
mean_binary[:] = 0
mean_wghted[:] = 0
# Cycle through subjects to compute average network
for i in xrange(numsubs):
mean_binary = mean_binary + (nws[:,:,i]!=0).astype(float)
mean_wghted = mean_wghted + nws[:,:,i]
# Kick out connections not present in at least `percval%` of subjects (in binary and weighted NWs)
mean_binary = (mean_binary/numsubs > percval).astype(float)
mean_wghted = mean_wghted/numsubs * mean_binary
# Check connectedness of mean network
if degrees_und(mean_binary).min() == 0:
print "WARNING: Mean network disconnected for `percval` = "+str(np.round(1e2*percval))+"%"
if percval > 0:
print "Decreasing `percval` by 5%..."
percval -= 0.05
print "New value for `percval` is now "+str(np.round(1e2*percval))+"%"
else:
msg = "Mean network disconnected for `percval` = 0%. That means at least one node is "+\
"disconnected in ALL per-subject networks..."
raise ValueError(msg)
else:
docalc = False
return mean_wghted, percval
##########################################################################################
def rm_negatives(corrs):
"""
Remove negative entries from connection matrices
Parameters
----------
corrs : NumPy 3darray
An array of `K` matrices of dimension `N`-by-`N`. Format is `corrs.shape = (N,N,K)`,
such that `corrs[:,:,i]` is the `i`-th `N x N` matrix
Returns
-------
nws : NumPy 3darray
Same format as input tensor but `corrs >= 0`.
Notes
-----
None
See also
--------
None
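Examples
--------
A usage sketch, assuming `corrs` is an `(N,N,K)` array containing some
negative entries:
>>> nws = rm_negatives(corrs)
>>> nws.min() >= 0
True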
"""
# Sanity checks
arrcheck(corrs,'tensor','corrs')
# See how many matrices are stacked in the array
K = corrs.shape[-1]
# Zero diagonals of matrices
for i in xrange(K):
np.fill_diagonal(corrs[:,:,i],0)
# Remove negative entries
nws = (corrs > 0)*corrs
# Check if we lost some nodes...
ndnum = str(corrs.shape[0])
for i in xrange(K):
deg = degrees_und(corrs[:,:,i])
if deg.min() == 0:
badidx = np.nonzero(deg==deg.min())[0]
print "WARNING: In network "+str(i)+" a total of "+str(badidx.size)+" out of "+ndnum+\
" node(s) got disconnected, namely vertices #"+str(badidx)
return nws
##########################################################################################
def rm_selfies(conns):
"""
Remove self-connections from connection matrices
Parameters
----------
conns : NumPy 3darray
An array of `K` connection matrices of dimension `N`-by-`N`. Format is `conns.shape = (N,N,K)`,
such that `conns[:,:,i]` is the `i`-th `N x N` connection matrix
Returns
-------
nws : NumPy 3darray
Same format as input array but with `np.diag(nws[:,:,k]) == 0` for all `k`.
Notes
-----
None
See also
--------
None
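Examples
--------
A usage sketch, assuming `conns` is an `(N,N,K)` array of real-valued
connection matrices:
>>> nws = rm_selfies(conns)
>>> np.diag(nws[:,:,0]).max()
0.0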
"""
# Sanity checks
arrcheck(conns,'tensor','conns')
# Create output quantity and zero its diagonals
nws = conns.copy()
for i in xrange(nws.shape[-1]):
np.fill_diagonal(nws[:,:,i],0)
return nws
##########################################################################################
def thresh_nws(nws,userdens=None,percval=0.0,force_den=False,span_tree=False):
"""
Threshold networks based on connection density
Parameters
----------
nws : NumPy 3darray
Undirected `N`-by-`N` (un)weighted connection matrices of `numsubs` subjects.
Format is `nws.shape = (N,N,numsubs)` such that `nws[:,:,i] = N x N`
connection matrix of `i`-th subject
userdens : int
By default, the input networks are thresholded down to the lowest common
connection density without disconnecting any nodes in the networks using
a relative thresholding strategy (`force_den = False` and `span_tree = False`).
If `userdens` is provided and `span_tree = False`, then `userdens`
is used as target density in the relative thresholding strategy. However,
if `userdens` is below the minimum density before networks fragment,
it will not be used unless `force_den = True`.
If `span_tree = True` and `userdens` is `None`, then maximum spanning
trees will be returned for all input networks. If `userdens` is provided,
the spanning trees will be populated with the strongest connections
found in the original networks up to the desired edge density.
For both relative thresholding and maximum spanning tree density reduction,
`userdens` should be either `None` or an integer between 0 and 100.
See Notes below for more details.
percval : float
Percentage value for computing mean network averaged across all thresholded
graphs, such that connections not present in at least `percval`
percent of subjects are not considered (`0 <= percval <= 1`).
Default setting is `percval = 0.0`. See `get_meannw` for details.
force_den : bool
If `force_den = True` relative thresholding is applied to the networks
until all graphs hit the desired density level defined by the user
even if nodes get disconnected in the process. This argument has no
effect if `span_tree = True`. By default, `force_den = False`.
span_tree : bool
If `span_tree` is `True` density reduction is performed by constructing maximum
spanning trees. If `userdens` is `None`, only spanning trees for all input networks
will be returned. If `userdens` is provided, spanning trees will be populated
with the strongest connections found in the original networks up to the
desired edge density. Note that `force_den` is ignored if `span_tree` is `True`.
Returns
-------
Dictionary holding computed quantities. The fields of the dictionary depend upon
the values of the optional keyword arguments `userdens` and `span_tree`.
res : dict
Dictionary with fields
th_nws : NumPy 3darray
Sparse networks. Format is the same as for `nws`
(Not returned if `userdens` is `None` and `span_tree = True`).
den_values : NumPy 1darray
Density values of the networks stored in `th_nws`, such that `den_values[i]`
is the edge density of the graph `th_nws[:,:,i]`
(not returned if `userdens` is `None` and `span_tree = True`).
th_mnw : NumPy 2darray
Mean network averaged across all sparse networks `th_nws`
(not returned if `userdens` is `None` and `span_tree = True`).
mnw_percval: float
Percentage value used to compute `th_mnw` (see documentation of `get_meannw` for
details, not returned if `userdens` is `None` and `span_tree = True`).
tau_levels : NumPy 1darray
Cutoff values used in the relative thresholding strategy to compute
`th_nws`, i.e., `tau_levels[i]` is the threshold that generated
network `th_nws[:,:,i]` (only returned if `span_tree = False`).
nws_forest : NumPy 3darray
Maximum spanning trees calculated for all input networks
(only returned if `span_tree = True`).
mean_tree : NumPy 2darray
Mean spanning tree averaged across all spanning trees stored in
`nws_forest` (only returned if `span_tree = True`).
mtree_percval : float
Percentage value used to compute `mean_tree` (see documentation of `get_meannw` for
details, only returned if `span_tree = True`).
Notes
-----
This routine uses either a relative thresholding strategy or a maximum spanning tree
approach to decrease the density of a given set of input networks.
During relative thresholding (`span_tree = False`) edges are discarded based on their value relative to the
maximum edge weight found across all networks beginning with the weakest links. By default,
the thresholding algorithm uses the lowest common connection density across all input networks
before a node is disconnected as target edge density. That means, if networks `A`, `B` and `C`
can be thresholded down to 40%, 50% and 60% density, respectively, without disconnecting any
nodes, then the lowest common density for thresholding `A`, `B` and `C` together is 60%.
In this case the raw network `A` already has a density of 60% or lower, which is thus excluded
from thresholding and the original network is copied into `th_nws`. If a density level
is provided by the user, then the code tries to use it unless it violates connectedness
of all thresholded networks - in this case the lowest common density of all networks is used,
unless `force_den = True` which causes the code to employ the user-provided density level
for thresholding, disconnecting nodes from the networks in the process.
The maximum spanning tree approach (`span_tree = True`) can be interpreted as the inverse of relative
thresholding. Instead of chipping away weak edges in the input networks until a target density
is met (or nodes disconnect), a minimal backbone of the network is calculated and then
populated with the strongest connections found in the original network until a desired
edge density level is reached. The backbone of the network is calculated by computing the graph's maximum
spanning tree, that connects all nodes with the minimum number of maximum-weight edges.
Note, that unless each edge has a distinct unique weight value a graph has numerous different
maximum spanning trees. Thus, the spanning trees computed by this routine are usually *not* unique,
and consequently the thresholded networks may not be unique either (particularly for low
density levels, for which the computed populated networks are very similar to the underlying spanning trees).
Thus, in contrast to the more common relative thresholding strategy, this bottom-up approach
allows one to reduce a given network's density to an almost arbitrary level
(>= density of the maximum spanning tree) without disconnecting nodes. However, unlike relative
thresholding, the computed sparse networks are not necessarily unique and strongly depend
on the initial maximum spanning tree. Note that if `userdens` is `None`, only maximum spanning
trees will be computed.
The code below relies on the routine `get_meannw` in this module to compute the group-averaged
network. Further, maximum spanning trees are calculated using `backbone_wu.m` from the
Brain Connectivity Toolbox (BCT) for MATLAB via Octave. Thus, it requires Octave to be installed
with the BCT in its search path. Further, `oct2py` is needed to launch an Octave instance
from within Python.
See also
--------
get_meannw : Helper function to compute group-averaged networks
backbone_wu : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
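Examples
--------
Two usage sketches, assuming `nws` is an `(N,N,numsubs)` array of symmetric,
non-negative connection matrices:
>>> res = thresh_nws(nws,userdens=40)                 # relative thresholding to 40% density
>>> res = thresh_nws(nws,userdens=40,span_tree=True)  # populate maximum spanning trees instead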
"""
# Sanity checks
arrcheck(nws,'tensor','nws')
if userdens is not None:
scalarcheck(userdens,'userdens',kind='int',bounds=[0,100])
scalarcheck(percval,'percval',bounds=[0,1])
if not isinstance(force_den,bool):
raise TypeError("The optional argument `force_den` has to be Boolean!")
if not isinstance(span_tree,bool):
raise TypeError("The optional argument `span_tree` has to be Boolean!")
if force_den and span_tree:
print "\nWARNING: The flag `foce_den` has no effect if `span_tree == True`!"
# Try to import `octave` from `oct2py`
if span_tree:
try:
from oct2py import octave
except:
errmsg = "Could not import octave from oct2py! "+\
"To compute the maximum spanning tree octave must be installed and in the search path. "+\
"Furthermore, the Brain Connectivity Toolbox (BCT) for MATLAB must be installed "+\
"in the octave search path. "
raise ImportError(errmsg)
# Get dimension of per-subject networks
N = nws.shape[0]
numsubs = nws.shape[-1]
# Zero diagonals and check for symmetry
for i in xrange(numsubs):
np.fill_diagonal(nws[:,:,i],0)
if issym(nws[:,:,i]) == False:
raise ValueError("Matrix "+str(i)+" is not symmetric!")
# Get max. and min. weights (min weight should be >= 0 otherwise the stuff below makes no sense...)
maxw = nws.max()
if nws.min() < 0:
raise ValueError('Only non-negative weights supported!')
# Allocate vector for original densities
raw_den = np.zeros((numsubs,))
# Compute densities of raw networks
for i in xrange(numsubs):
raw_den[i] = density_und(nws[:,:,i])
# Compute min/max density in raw data
min_raw = int(np.floor(1e2*raw_den.min()))
max_raw = int(np.ceil(1e2*raw_den.max()))
# Break if a nw has density zero or if max. density is below desired dens.
if min_raw == 0:
raise ValueError('Network '+str(raw_den.argmin())+' has density 0%!')
if userdens >= max_raw:
print "All networks have density lower than desired density "+str(userdens)+"%"
th_mnw,mnw_percval = get_meannw(nws,percval)
res_dict = {'th_nws':nws, 'den_values': raw_den, \
'th_mnw': th_mnw, 'mnw_percval': mnw_percval}
# The structure of `backbone_wu.m` requires *exact* symmetry...
if span_tree:
nws_forest = np.zeros(nws.shape)
for i in xrange(numsubs):
mnw = nws[:,:,i].squeeze()
mnw = np.triu(mnw,1)
nws_forest[:,:,i] = octave.backbone_wu(mnw + mnw.T,2)
mean_tree, mtree_percval = get_meannw(nws_forest,percval)
res_dict['nws_forest'] = nws_forest
res_dict['mean_tree'] = mean_tree
res_dict['mtree_percval'] = mtree_percval
else:
res_dict['tau_levels'] = None
return res_dict
# Inform user about minimal/maximal density in raw data
print "\nRaw data has following density values: \n"
print "\tMinimal density: "+str(min_raw)+"%"
print "\tMaximal density: "+str(max_raw)+"%"
# Allocate space for output (needed for both regular thresholding and de-foresting)
th_nws = np.zeros(nws.shape)
den_values = np.zeros((numsubs,))
th_mnw = np.zeros((N,N))
# Maximum spanning tree shenanigans
if span_tree:
# Allocate space for the spanning trees
nws_forest = np.zeros(nws.shape)
# If no target density was provided, just compute trees and get out of here
if userdens is None:
print "\nCalculating maximum spanning trees..."
for i in xrange(numsubs):
mnw = nws[:,:,i].squeeze()
mnw = np.triu(mnw,1)
nws_forest[:,:,i] = octave.backbone_wu(mnw + mnw.T,2)
mean_tree,mtree_percval = get_meannw(nws_forest,percval)
return {'nws_forest': nws_forest, 'mean_tree': mean_tree, 'mtree_percval': mtree_percval}
else:
# The edge density `d` of an undirected network is given by
# (1) `d = 2*K/(N**2 - N)`,
# where `K` denotes the number of edges in the network. Thus, `K` can be approximated by
# (2) `N*avdg/2`,
# with `avdg` denoting the average nodal degree in the graph (divide by two
# to not count links twice, since we have undirected links i <-> j, not i -> j and j <- i).
# Thus, substituting (2) for `K` in (1) and re-arranging terms yields
# `avdg = d*(N - 1)`. Thus, for a user-provided density value, we can compute
# the associated average degree of the wanted target network as
# avdg = np.round(userdens/100*(N**2 - N)/N)
avdg = np.round(userdens/100*(N - 1))
print "\nReducing network densities to "+str(userdens)+"% by inversely populating maximum spanning trees..."
# Use this average degree value to cut down input networks to desired density
for i in xrange(numsubs):
mnw = nws[:,:,i].squeeze()
mnw = np.triu(mnw,1)
raw_dper = int(np.round(1e2*raw_den[i]))
if raw_dper <= userdens:
print "Density of raw network #"+str(i+1)+" is "+str(raw_dper)+"%"+\
" which is already lower than thresholding density of "+str(userdens)+"%"
print "Returning original unthresholded network"
th_nws[:,:,i] = nws[:,:,i].copy()
den_values[i] = raw_den[i]
nws_forest[:,:,i] = octave.backbone_wu(mnw + mnw.T.squeeze(),2)
else:
nws_forest[:,:,i], th_nws[:,:,i] = octave.backbone_wu(mnw + mnw.T, avdg, nout=2)
den_values[i] = density_und(th_nws[:,:,i])
mean_tree,mtree_percval = get_meannw(nws_forest,percval)
# Populate results dictionary with method-specific quantities
res_dict = {'nws_forest': nws_forest, 'mean_tree': mean_tree, 'mtree_percval': mtree_percval}
# Here the good ol' relative weight thresholding
else:
# Allocate space for thresholds and thresholding stepsize
tau_levels = np.zeros((numsubs,))
dt = 1e-3
# Compute minimal admissible density per network
for i in xrange(numsubs):
mnw = nws[:,:,i]
tau = mnw.max(axis=0).min()
mnw = mnw*(mnw >= tau)
th_nws[:,:,i] = mnw.copy()
den_values[i] = density_und(mnw)
tau_levels[i] = tau - dt
# Compute minimal density before fragmentation across all subjects
densities = np.round(1e2*den_values)
print "\nMinimal admissible densities of per-subject networks are as follows: "
for i in xrange(densities.size): print "Subject #"+str(i+1)+": "+str(int(densities[i]))+"%"
min_den = int(np.round(1e2*den_values.max()))
print "\nThus, minimal density before fragmentation across all subjects is "+str(min_den)+"%"
# Assign thresholding density level
if userdens is None:
thresh_dens = min_den
else:
if userdens < min_den and force_den == False:
print "\nUser provided density of "+str(int(userdens))+\
"% lower than minimal admissible density of "+str(min_den)+"%. "
print "Using minimal admissible density instead. "
thresh_dens = min_den
elif userdens < min_den and force_den == True:
print "\nWARNING: Provided density of "+str(int(userdens))+\
"% leads to disconnected networks - proceed with caution..."
thresh_dens = int(userdens)
else:
thresh_dens = int(userdens)
# Inform the user about what's gonna happen
print "\nUsing density of "+str(int(thresh_dens))+"%. Starting thresholding procedure...\n"
# Backtracking parameter
beta = 0.3
# Cycle through subjects
for i in xrange(numsubs):
den_perc = 100
th = -dt
mnw = nws[:,:,i]
raw_dper = int(np.round(1e2*raw_den[i]))
if raw_dper <= thresh_dens:
print "Density of raw network #"+str(i)+" is "+str(raw_dper)+"%"+\
" which is already lower than thresholding density of "+str(thresh_dens)+"%"
print "Returning original unthresholded network"
th_nws[:,:,i] = mnw
tau_levels[i] = 0
den_values[i] = raw_den[i]
else:
while den_perc > thresh_dens:
th += dt
tau = th*maxw
mnw = mnw*(mnw >= tau).astype(float)
den = density_und(mnw)
den_perc = np.round(1e2*den)
if den_perc < thresh_dens:
th *= beta
th_nws[:,:,i] = mnw
tau_levels[i] = tau
den_values[i] = den
# Populate results dictionary with method-specific quantities
res_dict = {'tau_levels': tau_levels}
# Compute group average network
th_mnw,mnw_percval = get_meannw(th_nws,percval)
# Fill up results dictionary
res_dict['th_nws'] = th_nws
res_dict['den_values'] = den_values
res_dict['th_mnw'] = th_mnw
res_dict['mnw_percval'] = mnw_percval
# Be polite and dismiss the user
print "\nDone...\n"
return res_dict
##########################################################################################
def normalize(arr,vmin=0,vmax=1):
"""
Re-scales a NumPy ndarray
Parameters
----------
arr : NumPy ndarray
An array of size > 1 (shape can be arbitrary)
vmin : float
Floating point number representing the lower normalization bound.
(Note that it has to hold that `vmin < vmax`)
vmax : float
Floating point number representing the upper normalization bound.
(Note that it has to hold that `vmin < vmax`)
Returns
-------
arrn : NumPy ndarray
Scaled version of the input array `arr`, such that `arrn.min() == vmin` and
`arrn.max() == vmax`
Notes
-----
In contrast to Matplotlib's `Normalize`, *all* values of the input array are re-scaled,
even if outside the specified bounds. For instance, if `arr.min() == -1` and `arr.max() == 0.5` then
calling normalize with bounds `vmin = 0` and `vmax = 1` will result in an array `arrn`
satisfying `arrn.min() == 0` and `arrn.max() == 1`.
Examples
--------
>>> arr = np.array([[-1,.2],[100,0]])
>>> arrn = normalize(arr,vmin=-10,vmax=12)
>>> arrn
array([[-10.        ,  -9.73861386],
       [ 12.        ,  -9.78217822]])
See also
--------
None
"""
# Ensure that `arr` is a NumPy-ndarray
if not isinstance(arr,np.ndarray):
raise TypeError('Input `arr` has to be a NumPy ndarray!')
if arr.size == 1:
raise ValueError('Input `arr` has to be a NumPy ndarray of size > 1!')
if not np.issubdtype(arr.dtype, np.number) or not np.isreal(arr).all():
raise ValueError("Input array hast to be real-valued!")
if np.isfinite(arr).min() == False:
raise ValueError("Input `arr` must be real-valued without Inf's or NaN's!")
# If normalization bounds are user specified, check them
scalarcheck(vmin,'vmin')
scalarcheck(vmax,'vmax')
if vmax <= vmin:
raise ValueError('Lower bound `vmin` has to be strictly smaller than upper bound `vmax`!')
if np.absolute(vmin - vmax) < 2*np.finfo(float).eps:
raise ValueError('Bounds too close: `|vmin - vmax| < eps`, no normalization possible')
# Get min and max of array
arrmin = arr.min()
arrmax = arr.max()
# If min and max values of array are identical do nothing, if they differ close to machine precision, abort
if arrmin == arrmax:
return arr
elif np.absolute(arrmin - arrmax) <= np.finfo(float).eps:
raise ValueError('Minimal and maximal values of array too close, no normalization possible')
# Return normalized array
return (arr - arrmin)*(vmax - vmin)/(arrmax - arrmin) + vmin
##########################################################################################
def csv2dict(csvfile):
"""
Reads 3D nodal coordinates from a csv file into a Python dictionary
Parameters
----------
csvfile : str
File-name of (or full path to) the csv file holding nodal coordinates.
The format of this file HAS to be
`x, y, z`
`x, y, z`
`x, y, z`
.
.
for each node. Thus `#rows = #nodes`.
Returns
-------
mydict : dict
Nodal coordinates as read from the input csv file. Format is
`{0: (x, y, z),`
` 1: (x, y, z),`
` 2: (x, y, z),`
.
.
Thus the dictionary has `#nodes` keys.
Notes
-----
None
See also
--------
None
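Examples
--------
A usage sketch, assuming the hypothetical file `coords.csv` holds one
`x, y, z` triplet per row:
>>> coords = csv2dict('coords.csv')
>>> len(coords)                     # equals the number of rows, i.e., nodes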
"""
# Make sure `csvfile` makes sense
if not isinstance(csvfile,(str,unicode)):
raise TypeError("Name of csv-file has to be a string!")
if csvfile.find("~") == 0:
csvfile = os.path.expanduser('~') + csvfile[1:]
if not os.path.isfile(csvfile):
raise ValueError('File: `'+csvfile+'` does not exist!')
# Open `csvfile`
fh = open(csvfile,'rU')
fh.seek(0)
# Read nodal coordinates
reader = csv.reader(fh, dialect='excel',delimiter=',', quotechar='"')
# Iterate over rows and convert coordinates from string lists to float tuples
mydict = {}
i = 0
for i, row in enumerate(reader):
try:
mydict[i] = tuple([float(r) for r in row])
except ValueError as ve:
raise ValueError("Error reading file `"+str(csvfile)+"` on line "+str(i+1)+": "+ve.message)
return mydict
##########################################################################################
def shownet(A,coords,colorvec=None,sizevec=None,labels=None,threshs=[.8,.3,0],lwdths=[5,2,.1],nodecmap='jet',edgecmap='jet',textscale=3):
"""
Plots a network in 3D using Mayavi
Parameters
----------
A : NumPy 2darray
Square `N`-by-`N` connection matrix of the network
coords: dict
Nodal coordinates of the graph. Format is
`{0: (x, y, z),`
` 1: (x, y, z),`
` 2: (x, y, z),`
.
.
Note that the dictionary has to have `N` keys.
colorvec : NumPy 1darray
Vector of color-values for each node. This could be nodal strength or modular information of nodes
(i.e., to which module does node `i` belong to). Thus `colorvec` has to be of length `N` and all its
components must be in `[0,1]`.
sizevec : NumPy 1darray
Vector of nodal sizes. This could be degree, centrality, etc. Thus `sizevec` has to be of length
`N` and all its components must be `>= 0`.
labels : list or NumPy 1darray
Nodal labels. Format is `['Name1','Name2','Name3',...]` where the ordering HAS to be the same
as in the `coords` dictionary. Note that the list/array has to have length `N`.
threshs : list or NumPy 1darray
Thresholds for visualization. Edges with weights larger than `threshs[0]` are drawn
thickest, weights `> threshs[1]` are thinner and so on. Note that if `threshs[-1]>0` not all
edges of the network are plotted (since edges with `0 < weight < threshs[-1]` will be ignored).
lwdths : list or NumPy 1darray
Line-widths associated to the thresholds provided by `threshs`. Edges with weights larger than
`threshs[0]` are drawn with line-width `lwdths[0]`, edges with `weights > threshs[1]`
have line-width `lwdths[1]` and so on. Thus `len(lwdths) == len(threshs)`.
nodecmap : str
Mayavi colormap to be used for plotting nodes. See Notes for details.
edgecmap : str
Mayavi colormap to be used for plotting edges. See Notes for details.
textscale : float
Scaling factor for labels (larger numbers -> larger text)
Returns
-------
Nothing : None
Notes
-----
A list of available colormaps in Mayavi is currently available
`here <http://docs.enthought.com/mayavi/mayavi/mlab_changing_object_looks.html>`_.
See the
`Mayavi documentation <http://docs.enthought.com/mayavi/mayavi/auto/mlab_helper_functions.html>`_
for more info.
See also
--------
show_nw : A Matplotlib based implementation with extended functionality (but MUCH slower rendering)
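Examples
--------
A usage sketch, assuming `A` is an `N`-by-`N` connection matrix and `coords`
a matching coordinate dictionary (e.g., generated by `csv2dict`):
>>> shownet(A,coords,colorvec=strengths_und(A)/strengths_und(A).max(),
...         sizevec=degrees_und(A),threshs=[.8,.3,0],lwdths=[5,2,.1])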
"""
# For those lucky enough to have a running installation of Mayavi...
try:
from mayavi import mlab
except:
msg = 'Mayavi could not be imported. You might want to try `show_nw`, a slower (but more feature '\
+'rich) graph rendering routine based on Matplotlib.'
raise ImportError(msg)
# Make sure the adjacency/weighting matrix makes sense
arrcheck(A,'matrix','A')
(N,M) = A.shape
# Check the coordinate dictionary
try:
bad = (len(coords.keys()) != N)
except:
raise TypeError("The coordinates have to be provided as dictionary!")
if bad:
raise ValueError('The coordinate dictionary has to have N keys!')
for val in coords.values():
if not isinstance(val,(list,np.ndarray)):
raise TypeError('All elements of the `coords` dictionary have to be lists/arrays!')
arrcheck(np.array(val),'vector','coordinates')
if len(val) != 3:
raise ValueError('All elements of the coords dictionary have to be 3-dimensional!')
# Check `colorvec` if provided, otherwise assign default value
if colorvec is not None:
arrcheck(colorvec,'vector','colorvec',bounds=[0,1])
if colorvec.size != N:
raise ValueError('`colorvec` has to have length `N`!')
else:
colorvec = np.ones((N,))
# Same for `sizevec`
if sizevec is not None:
arrcheck(sizevec,'vector','sizevec',bounds=[0,np.inf])
if sizevec.size != N:
raise ValueError('`sizevec` has to have length `N`!')
else:
sizevec = np.ones((N,))
# Check labels (if any provided)
if labels is not None:
try:
bad = (len(labels) != N)
except:
raise TypeError("Nodal labels have to be provided as list/NumPy 1darray!")
if bad:
raise ValueError("Number of nodes and labels does not match up!")
for lb in labels:
if not isinstance(lb,(str,unicode)):
raise ValueError('Each individual label has to be a string type!')
else:
labels = []
# Check thresholds and linewidths
if not isinstance(threshs,(list,np.ndarray)):
raise TypeError("Visualization thresholds have to be provided as list/NumPy 1darray!")
threshs = np.array(threshs)
arrcheck(threshs,'vector','threshs')
n = threshs.size
if not isinstance(lwdths,(list,np.ndarray)):
raise TypeError("Linewidths have to be provided as list/NumPy 1darray!")
lwdths = np.array(lwdths)
arrcheck(lwdths,'vector','lwdths')
m = lwdths.size
if m != n:
raise ValueError("Number of thresholds and linewidths does not match up!")
# Make sure colormap definitions were given as strings
if not isinstance(nodecmap,(str,unicode)):
raise TypeError("Colormap for nodes has to be provided as string!")
if not isinstance(edgecmap,(str,unicode)):
raise TypeError("Colormap for edges has to be provided as string!")
# Check `textscale`
scalarcheck(textscale,'textscale')
# Now start to actually do something...
pts = mlab.quiver3d(np.array([coords[i][0] for i in coords.keys()]),\
np.array([coords[i][1] for i in coords.keys()]),\
np.array([coords[i][2] for i in coords.keys()]),\
sizevec,sizevec,sizevec,scalars=colorvec,\
scale_factor = 1,mode='sphere',colormap=nodecmap)
# Coloring of the balls is based on the provided scalars
pts.glyph.color_mode = 'color_by_scalar'
# Finally, center the glyphs on the data point
pts.glyph.glyph_source.glyph_source.center = [0, 0, 0]
# Cycle through threshold levels to generate different line-widths of networks
srcs = []; lines = []
for k in xrange(len(threshs)):
# Generate empty lists to hold (x,y,z) data and color information
x = list()
y = list()
z = list()
s = list()
connections = list()
index = 0
b = 2
# Get matrix entries > current threshold level
for i in xrange(N):
for j in xrange(i+1,N):
if A[i,j] > threshs[k]:
x.append(coords[i][0])
x.append(coords[j][0])
y.append(coords[i][1])
y.append(coords[j][1])
z.append(coords[i][2])
z.append(coords[j][2])
s.append(A[i][j])
s.append(A[i][j])
connections.append(np.vstack([np.arange(index, index + b - 1.5), np.arange(index+1, index + b - 0.5)]).T)
index += b
# Finally generate lines connecting dots
srcs.append(mlab.pipeline.scalar_scatter(x,y,z,s))
srcs[-1].mlab_source.dataset.lines = connections
lines.append(mlab.pipeline.stripper(srcs[-1]))
mlab.pipeline.surface(lines[-1], colormap=edgecmap, line_width=lwdths[k], vmax=1, vmin=0)
# Label nodes if wanted
for i in xrange(len(labels)):
mlab.text3d(coords[i][0]+2,coords[i][1],coords[i][2],labels[i],color=(0,0,0),scale=textscale)
return
##########################################################################################
def show_nw(A,coords,colorvec=None,sizevec=None,labels=None,nodecmap=plt.get_cmap(name='jet'),edgecmap=plt.get_cmap(name='jet'),linewidths=None,nodes3d=False,viewtype='axial'):
"""
Matplotlib-based plotting routine for networks
Parameters
----------
A : NumPy 2darray
Square `N`-by-`N` connection matrix of the network
coords: dict
Nodal coordinates of the graph. Format is
`{0: (x, y, z),`
` 1: (x, y, z),`
` 2: (x, y, z),`
.
.
Note that the dictionary has to have `N` keys.
colorvec : NumPy 1darray
Vector of color-values for each node. This could be nodal strength or modular information of nodes
(i.e., to which module does node i belong to). Thus `colorvec` has to be of length `N` and all its
components must be in `[0,1]`.
sizevec : NumPy 1darray
Vector of nodal sizes. This could be degree, centrality, etc. Thus `sizevec` has to be of
length `N` and all its components must be `>= 0`.
labels : list or NumPy 1darray
Nodal labels. Format is `['Name1','Name2','Name3',...]` where the ordering HAS to be the same
as in the `coords` dictionary. Note that the list/array has to have length `N`.
nodecmap : Matplotlib colormap
Colormap to use for plotting nodes
edgecmap : Matplotlib colormap
Colormap to use for plotting edges
linewidths : NumPy 2darray
Same format and nonzero-pattern as `A`. If no linewidths are provided then the edge connecting
nodes `v_i` and `v_j` is plotted using the linewidth `A[i,j]`. By specifying, e.g.,
`linewidths = (1+A)**2`, the thickness of edges in the network-plot can be scaled.
nodes3d : bool
If `nodes3d=True` then nodes are plotted using 3d spheres in space (with `diameters = sizevec`).
If `nodes3d=False` then the Matplotlib `scatter` function is used to plot nodes as flat
2d disks (faster).
viewtype : str
Camera position, `viewtype` can be one of the following
`axial (= axial_t)` : Axial view from top down
`axial_t` : Axial view from top down
`axial_b` : Axial view from bottom up
`sagittal (= sagittal_l)` : Sagittal view from left
`sagittal_l` : Sagittal view from left
`sagittal_r` : Sagittal view from right
`coronal (= coronal_f)` : Coronal view from front
`coronal_f` : Coronal view from front
`coronal_b` : Coronal view from back
Returns
-------
Nothing : None
Notes
-----
See Matplotlib's `mplot3d tutorial <http://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html>`_
See also
--------
shownet : A Mayavi based implementation with less functionality but MUCH faster rendering
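Examples
--------
A usage sketch, assuming `A` is an `N`-by-`N` connection matrix and `coords`
a matching coordinate dictionary:
>>> show_nw(A,coords,linewidths=(1+A)**2,nodes3d=True,viewtype='sagittal_l')
>>> plt.show()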
"""
# Check the graph's connection matrix
arrcheck(A,'matrix','A')
(N,M) = A.shape
# Check the coordinate dictionary
try:
bad = (len(coords.keys()) != N)
except:
raise TypeError("The coordinates have to be provided as dictionary!")
if bad:
raise ValueError('The coordinate dictionary has to have N keys!')
for val in coords.values():
if not isinstance(val,(list,np.ndarray)):
raise TypeError('All elements of the coords dictionary have to be lists/arrays!')
arrcheck(np.array(val),'vector','coordinates')
if len(val) != 3:
raise ValueError('All elements of the coords dictionary have to be 3-dimensional!')
# Check `colorvec` if provided, otherwise assign default value
if colorvec is not None:
arrcheck(colorvec,'vector','colorvec',bounds=[0,1])
if colorvec.size != N:
raise ValueError('`colorvec` has to have length `N`!')
else:
colorvec = np.ones((N,))
# Same for `sizevec`
if sizevec is not None:
arrcheck(sizevec,'vector','sizevec',bounds=[0,np.inf])
if sizevec.size != N:
raise ValueError('`sizevec` has to have length `N`!')
else:
sizevec = np.ones((N,))
# Check labels (if any provided)
if labels is not None:
try:
bad = (len(labels) != N)
except:
raise TypeError("Nodal labels have to be provided as list/NumPy 1darray!")
if bad:
raise ValueError("Number of nodes and labels does not match up!")
for lb in labels:
if not isinstance(lb,(str,unicode)):
raise ValueError('Each individual label has to be a string type!')
else:
labels = []
# Check the colormaps
if type(nodecmap).__name__ != 'LinearSegmentedColormap':
raise TypeError('Nodal colormap has to be a Matplotlib colormap!')
if type(edgecmap).__name__ != 'LinearSegmentedColormap':
raise TypeError('Edge colormap has to be a Matplotlib colormap!')
# If no linewidths were provided, use the entries of `A` to control edge thickness
if linewidths is not None:
arrcheck(linewidths,'matrix','linewidths')
(ln,lm) = linewidths.shape
if linewidths.shape != A.shape:
raise ValueError("Linewidths must be provided as square array of the same dimension as the connection matrix!")
else:
linewidths = A
# Make sure `nodes3d` is Boolean
if not isinstance(nodes3d,bool):
raise TypeError('The nodes3d flag has to be a Boolean variable!')
# Check if `viewtype` is anything strange
if not isinstance(viewtype,(str,unicode)):
raise TypeError("Viewtype must be 'axial(_{t/b})', 'sagittal(_{l/r})' or 'coronal(_{f/b})'")
# Turn on 3d projection
ax = plt.gcf().gca(projection='3d')
ax.hold(True)
# Extract nodal x-, y-, and z-coordinates from the coords-dictionary
x = np.array([coords[i][0] for i in coords.keys()])
y = np.array([coords[i][1] for i in coords.keys()])
z = np.array([coords[i][2] for i in coords.keys()])
# Order matters here: FIRST plot connections, THEN nodes on top of connections (looks weird otherwise)
# Cycle through the matrix and plot every single connection line-by-line (this is *really* slow)
for i in xrange(N):
for j in xrange(i+1,N):
if A[i,j] > 0:
plt.plot([x[i],x[j]],[y[i],y[j]],[z[i],z[j]],linewidth=linewidths[i][j],color=edgecmap(A[i][j]))
# Plot nodes (either 3d spheres or flat scatter points)
if nodes3d == False:
plt.scatter(x,y,zs=z,marker='o',s=sizevec,c=colorvec,cmap=nodecmap)
else:
n = 20
theta = np.arange(-n,n+1,2)/n*np.pi
phi = np.arange(-n,n+1,2)/n*np.pi/2
cosphi = np.cos(phi); cosphi[0] = 0; cosphi[-1] = 0
sinth = np.sin(theta); sinth[0] = 0; sinth[-1] = 0
xx = np.outer(cosphi,np.cos(theta))
yy = np.outer(cosphi,sinth)
zz = np.outer(np.sin(phi),np.ones((n+1,)))
for i in xrange(x.size):
rd = sizevec[i]
ax.plot_surface(rd*xx+x[i],rd*yy+y[i],rd*zz+z[i],\
color = nodecmap(colorvec[i]),\
cstride=1,rstride=1,linewidth=0)
# Label nodes if wanted
for i in xrange(len(labels)):
ax.text(x[i]+2,y[i]+2,z[i]+2,labels[i],color='k',fontsize=14)
# If `viewtype` was specified as 'axial', 'coronal' or 'sagittal' use default (top, front, left) viewtypes
if viewtype == 'axial':
viewtype = 'axial_t'
elif viewtype == 'sagittal' or viewtype == 'sagital':
viewtype = 'sagittal_l'
elif viewtype == 'coronal':
viewtype = 'coronal_f'
# Turn off axis (don't really mean anything in this context anyway...) and set up view
if viewtype == 'axial_t':
ax.view_init(elev=90,azim=-90)
elif viewtype == 'axial_b':
ax.view_init(elev=-90,azim=90)
elif viewtype == 'coronal_f':
ax.view_init(elev=0,azim=90)
elif viewtype == 'coronal_b':
ax.view_init(elev=0,azim=270)
elif viewtype == 'sagittal_l':
ax.view_init(elev=0,azim=180)
elif viewtype == 'sagittal_r':
ax.view_init(elev=0,azim=0)
else:
print "WARNING: Unrecognized viewtype: "+viewtype
print "Using default viewtype `axial` instead"
ax.view_init(elev=90,azim=-90)
plt.axis('scaled')
plt.axis('off')
return
##########################################################################################
def generate_randnws(nw,M,method="auto",rwr=5,rwr_max=10):
"""
Generate random networks given a(n) (un)signed (un)weighted (un)directed input network
Parameters
----------
nw : NumPy 2darray
Connection matrix of input network
M : int
Number of random networks to generate (> 1)
method : str
String specifying which method to use to randomize
the input network. Currently supported options are
`'auto'` (default), `'null_model_und_sign'`, `'randmio_und'`, `'randmio_und_connected'`,
`'null_model_dir_sign'`, `'randmio_dir'`, `'randmio_dir_connected'`,
`'randmio_und_signed'`, `'randmio_dir_signed'`.
If `method = 'auto'` then a randomization strategy is chosen based
on the properties of the input network (directedness, edge-density, sign of
edge weights). In case of very dense networks (density > 75%) the `null_model`
routines are used to at least shuffle the input network's edge weights.
rwr : int
Number of approximate rewirings per edge (default: 5).
rwr_max : int
Maximal number of rewirings per edge to enforce randomization (default: 10).
Note that `rwr_max` has to be greater than or equal to `rwr`.
Returns
-------
rnws : NumPy 3darray
Random networks based on input graph `nw`. Format is `rnws.shape = (N,N,M)`
such that `rnws[:,:,m] = m-th N x N` random network
Notes
-----
This routine calls functions from the Brain Connectivity Toolbox (BCT) for MATLAB via Octave.
Thus, it requires Octave to be installed with the BCT in its search path. Further,
`oct2py` is needed to launch an Octave instance from within Python.
See also
--------
randmio_und_connected : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
randmio_dir_connected : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
randmio_und : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
randmio_dir : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
randmio_und_signed : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
randmio_dir_signed : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
null_model_und_sign : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
null_model_dir_sign : in the Brain Connectivity Toolbox (BCT) for MATLAB, currently available
`here <https://sites.google.com/site/bctnet/>`_
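Examples
--------
A minimal usage sketch (assuming Octave and the BCT are set up as described above;
the hypothetical 90-node input matrix below is just a stand-in for real data):
>>> nw = np.random.rand(90,90)
>>> nw = (nw + nw.T)/2; np.fill_diagonal(nw,0)    # make `nw` symmetric with zero diagonal
>>> rnws = generate_randnws(nw,20)
>>> rnws.shape
(90, 90, 20)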
"""
# Try to import `octave` from `oct2py`
try:
from oct2py import octave
except:
errmsg = "Could not import octave from oct2py! "+\
"To use this routine octave must be installed and in the search path. "+\
"Furthermore, the Brain Connectivity Toolbox (BCT) for MATLAB must be installed "+\
"in the octave search path. "
raise ImportError(errmsg)
# Check the two mandatory inputs
arrcheck(nw,'matrix','nw')
N = nw.shape[0]
scalarcheck(M,'M',kind='int',bounds=[1,np.inf])
# See if the string `method` is one of the supported randomization algorithms
supported = ["auto","randmio_und_connected","randmio_und","null_model_und_sign",\
"randmio_dir_connected","randmio_dir","null_model_dir_sign",\
"randmio_und_signed","randmio_dir_signed"]
if supported.count(method) == 0:
msg = 'Network cannot be randomized with `'+str(method)+\
'`. Available options are: '+''.join(supp+', ' for supp in supported)[:-2]
raise ValueError(msg)
# See if `rwr` makes sense
scalarcheck(rwr,'rwr',kind='int',bounds=[1,np.inf])
# Now `rwr_max`
scalarcheck(rwr_max,'rwr_max',kind='int',bounds=[rwr,np.inf])
# Try to import progressbar module
try:
import progressbar as pb
showbar = True
except:
print "WARNING: progressbar module not found - consider installing it using `pip install progressbar`"
showbar = False
# Allocate space for random networks
rnws = np.empty((N,N,M))
rnw = np.empty((N,N))
rw = rwr
# Unless the user explicitly specified a randomization strategy, choose one based on the
# input network's properties
if method == "auto":
min_nw = nw.min()
sgds = ["unsigned","signed"][min_nw<0]
if issym(nw): # undirected graphs
drct = "undirected"
dns = density_und(nw)
if dns > 0.75:
randomizer = octave.null_model_und_sign
else:
if min_nw < 0:
randomizer = octave.randmio_und_signed
else:
randomizer = octave.randmio_und
else: # directed graphs
drct = "directed"
dns = octave.density_dir(nw)
if dns > 0.75:
randomizer = octave.null_model_dir_sign
else:
if min_nw < 0:
randomizer = octave.randmio_dir_signed
else:
randomizer = octave.randmio_dir
print "Input network is "+drct+" and "+sgds+" with an edge-density of "+str(np.round(1e2*dns))+"%. "+\
"Using `"+randomizer.__name__+"` for randomization..."
# Depending on whether the chosen randomizer returns effective re-wiring numbers, a slightly different
# while loop structure is necessary
use_nm = randomizer.__name__.find('null_model') >= 0
# If available, initialize progressbar
if (showbar):
widgets = ['Calculating Random Networks: ',pb.Percentage(),' ',pb.Bar(marker='#'),' ',pb.ETA()]
pbar = pb.ProgressBar(widgets=widgets,maxval=M)
# Populate tensor
if (showbar): pbar.start()
if use_nm:
for m in xrange(M):
rwr = rw
ok = False
while rwr <= rwr_max and ok == False:
rnw = randomizer(nw,rwr,1)
ok = not np.allclose(rnw,nw)
rwr += 1
if not ok:
print "WARNING: network "+str(m)+" has not been randomized!"
rnws[:,:,m] = rnw.copy()
if (showbar): pbar.update(m)
else:
for m in xrange(M):
rwr = rw
eff = 0
while rwr <= rwr_max and eff == 0:
rnw,eff = randomizer(nw,rwr)
rwr += 1
if eff == 0:
print "WARNING: network "+str(m)+" has not been randomized!"
rnws[:,:,m] = rnw.copy()
if (showbar): pbar.update(m)
if (showbar): pbar.finish()
return rnws
##########################################################################################
[docs]def hdfburp(f):
"""
Pump out everything stored in an HDF5 container
Parameters
----------
f : h5py file object
File object created using `h5py.File()`
Returns
-------
Nothing : None
Notes
-----
This function takes an `h5py`-file object and creates variables in the caller's
local name-space corresponding to the respective dataset-names in the file.
The naming format of the generated variables is `groupname_datasetname`,
where the `groupname` is empty for datasets in the `root` directory of the file.
Thus, if an HDF5 file contains the datasets
`/a`
`/b`
`/group1/c`
`/group1/d`
`/group2/a`
`/group2/b`
then this routine creates the variables
`a`
`b`
`group1_c`
`group1_d`
`group2_a`
`group2_b`
in the caller's workspace.
The black magic part of the code was taken from Pykler's answer to
`this stackoverflow question <http://stackoverflow.com/questions/2515450/injecting-variables-into-the-callers-scope>`_
WARNING: EXISTING VARIABLES IN THE CALLER'S WORKSPACE ARE MERCILESSLY OVERWRITTEN!!!
See also
--------
h5py : a Pythonic interface to the HDF5 binary data format.
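Examples
--------
A minimal sketch, assuming a file `myfile.h5` exists and contains a dataset `/group1/c`:
>>> import h5py
>>> f = h5py.File('myfile.h5','r')
>>> hdfburp(f)
>>> arr = group1_c    # the dataset is now available as the local variable `group1_c`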
"""
# Sanity checks
if str(f).find('HDF5 file') < 0:
raise TypeError('Input must be a valid HDF5 file identifier!')
# Initialize necessary variables
mymap = {}
grplist = [f]
nameprefix = ''
# As long as we find groups in the file, keep iterating
while len(grplist) > 0:
# Get current group (in the first iteration, that's just the file itself)
mygrp = grplist[0]
# If it actually is a group, extract the group name to prefix to variable names
if len(mygrp.name) > 1:
nameprefix = mygrp.name[1::]+'_'
# Iterate through group items
for it in mygrp.items():
# If the current item is a group, add it to the list and keep going
if str(it[1]).find('HDF5 group') >= 0:
grplist.append(f[it[0]])
# If we found a variable, name it following the scheme: `groupname_varname`
else:
varname = nameprefix + it[0]
mymap[varname] = it[1].value
# Done with the current group, pop it from list
grplist.pop(grplist.index(mygrp))
# Update caller's variable scope (this is black magic...)
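# NOTE: writing to a frame's `f_locals` reliably creates variables for module-level/interactive
# callers in CPython, but is not guaranteed to work inside functions or on other interpreters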
stack = inspect.stack()
try:
locals_ = stack[1][0].f_locals
finally:
del stack
locals_.update(mymap)
##########################################################################################
def normalize_time_series(time_series_array):
"""
Normalizes a (real/complex) time series to zero mean and unit variance.
WARNING: Modifies the given array in place!
Parameters
----------
time_series_array : NumPy 2d array
Array of data values per time point. Format is: `timepoints`-by-`N`
Returns
-------
Nothing : None
Notes
-----
This function does *not* do any error checking and assumes you know what you are doing
This function is part of the `pyunicorn` package, developed by
Jonathan F. Donges and Jobst Heitzig. The package is currently available
`here <http://www.pik-potsdam.de/~donges/pyunicorn/index.html>`_
See also
--------
pyunicorn : A UNIfied COmplex Network and Recurrence aNalysis toolbox
Examples
--------
>>> ts = np.arange(16).reshape(4,4).astype("float")
>>> normalize_time_series(ts)
>>> ts.mean(axis=0)
array([ 0., 0., 0., 0.])
>>> ts.std(axis=0)
array([ 1., 1., 1., 1.])
>>> ts[:,0]
array([-1.34164079, -0.4472136 , 0.4472136 , 1.34164079])
"""
# Remove mean value from time series at each node (grid point)
time_series_array -= time_series_array.mean(axis=0)
# Normalize the variance of anomalies to one
time_series_array /= np.sqrt((time_series_array *
time_series_array.conjugate()).mean(axis=0))
# Correct for grid points with zero variance in their time series
time_series_array[np.isnan(time_series_array)] = 0
##########################################################################################
[docs]def mutual_info(tsdata, n_bins=32, normalized=True, fast=True, norm_ts=True):
"""
Calculate a (normalized) mutual information matrix at zero lag
Parameters
----------
tsdata : NumPy 2d array
Array of data values per time point. Format is: `timepoints`-by-`N`. Note that
both `timepoints` and `N` have to be `>= 2` (i.e., the code needs at least two time-series
of minimum length 2)
n_bins : int
Number of bins for estimating probability distributions
normalized : bool
If `True`, the normalized mutual information (NMI) is computed
otherwise the raw mutual information (not bounded from above) is calculated
(see Notes for details).
fast : bool
Use C++ code to calculate (N)MI. If `False`, then
a (significantly) slower Python implementation is employed
(provided in case the compilation of the C++ code snippets
fails on a system)
norm_ts : bool
If `True` the input time-series is normalized to zero mean and unit variance (default).
Returns
-------
mi : NumPy 2d array
`N`-by-`N` matrix of pairwise (N)MI coefficients of the input time-series
Notes
-----
For two random variables :math:`X` and :math:`Y` the raw mutual information
is given by
.. math:: MI(X,Y) = H(X) + H(Y) - H(X,Y),
where :math:`H(X)` and :math:`H(Y)` denote the Shannon entropies of
:math:`X` and :math:`Y`, respectively, and :math:`H(X,Y)` is their joint
entropy. By default, this function normalizes the raw mutual information
:math:`MI(X,Y)` by the geometric mean of :math:`H(X)` and :math:`H(Y)`
.. math:: NMI(X,Y) = {MI(X,Y)\over\sqrt{H(X)H(Y)}}.
The heavy lifting in this function is mainly done by code parts taken from
the `pyunicorn` package, developed by Jonathan F. Donges
and Jobst Heitzig [1]_. It is currently available
`here <http://www.pik-potsdam.de/~donges/pyunicorn/index.html>`_
The code has been modified so that weave and pure Python codes are now
part of the same function. Further, the original code computes the raw mutual information
only. Both Python and C++ parts have been extended to compute a normalized
mutual information too.
See also
--------
pyunicorn.pyclimatenetwork.mutual_info_climate_network : classes in this module
Examples
--------
>>> tsdata = np.random.rand(150,2) # 2 time-series of length 150
>>> NMI = mutual_info(tsdata)
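To compute the raw (non-normalized) MI using the pure Python fallback instead of the C++ path use
>>> MI = mutual_info(tsdata,normalized=False,fast=False)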
References
----------
.. [1] Copyright (C) 2008-2015, Jonathan F. Donges (Potsdam-Institute for Climate
Impact Research), pyunicorn authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of pyunicorn authors and the Potsdam-Institute for
Climate Impact Research nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
# Sanity checks (`tsdata` is probably not square, that's why we don't use `arrcheck` here)
try:
shtsdata = tsdata.shape
except:
raise TypeError('Input must be a timepoint-by-index NumPy 2d array, not '+type(tsdata).__name__+'!')
if len(shtsdata) != 2:
raise ValueError('Input must be a timepoint-by-index NumPy 2d array')
if min(shtsdata) < 2:
raise ValueError('At least two time-series/two time-points are required to compute (N)MI!')
if not np.issubdtype(tsdata.dtype, np.number) or not np.isreal(tsdata).all():
raise TypeError("Input must be real-valued!")
if np.isfinite(tsdata).min() == False:
raise ValueError('Input must be a real valued NumPy 2d array without Infs or NaNs!')
scalarcheck(n_bins,'n_bins',kind='int',bounds=[2,np.inf])
n_bins = int(n_bins)
for bvar in [normalized,fast,norm_ts]:
if not isinstance(bvar,bool):
raise TypeError('The flags `normalized`, `fast` and `norm_ts` must be Boolean!')
# Get faster reference to length of time series = number of samples
# per grid point.
(n_samples,N) = tsdata.shape
# Normalize `tsdata` time series to zero mean and unit variance
if norm_ts:
normalize_time_series(tsdata)
# Initialize mutual information array
mi = np.zeros((N,N), dtype="float32")
# Execute C++ code
if (fast):
# Create local transposed copy of `tsdata`
tsdata = np.fastCopyAndTranspose(tsdata)
# Get common range for all histograms
range_min = float(tsdata.min())
range_max = float(tsdata.max())
# Re-scale all time series to the interval [0,1],
# using the maximum range of the whole dataset.
denom = range_max - range_min + 1 - (range_max != range_min)
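# (`denom` equals `range_max - range_min` unless the data is constant, in which
# case it equals 1 - this guards the division below against a zero denominator)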
scaling = float(1. / denom)
# Create array to hold symbolic trajectories
symbolic = np.empty(tsdata.shape, dtype="int32")
# Initialize array to hold 1d-histograms of individual time series
hist = np.zeros((N,n_bins), dtype="int32")
# Initialize array to hold 2d-histogram for one pair of time series
hist2d = np.zeros((n_bins,n_bins), dtype="int32")
# C++ code to compute NMI
code_nmi = r"""
int i, j, k, l, m;
int symbol, symbol_i, symbol_j;
double norm, rescaled, hpl, hpm, plm, Hl, Hm;
// Calculate histogram norm
norm = 1.0 / n_samples;
for (i = 0; i < N; i++) {
for (k = 0; k < n_samples; k++) {
// Calculate symbolic trajectories for each time series,
// where the symbols are bins.
rescaled = scaling * (tsdata(i,k) - range_min);
if (rescaled < 1.0) {
symbolic(i,k) = rescaled * n_bins;
}
else {
symbolic(i,k) = n_bins - 1;
}
// Calculate 1d-histograms for single time series
symbol = symbolic(i,k);
hist(i,symbol) += 1;
}
}
for (i = 0; i < N; i++) {
for (j = 0; j <= i; j++) {
// The case `i = j` is not of interest here!
if (i != j) {
// Calculate 2d-histogram for one pair of time series
// (i,j).
for (k = 0; k < n_samples; k++) {
symbol_i = symbolic(i,k);
symbol_j = symbolic(j,k);
hist2d(symbol_i,symbol_j) += 1;
}
// Calculate mutual information for one pair of time
// series (i,j).
Hl = 0;
for (l = 0; l < n_bins; l++) {
hpl = hist(i,l) * norm;
if (hpl > 0.0) {
Hl += hpl * log(hpl);
Hm = 0;
for (m = 0; m < n_bins; m++) {
hpm = hist(j,m) * norm;
if (hpm > 0.0) {
Hm += hpm * log(hpm);
plm = hist2d(l,m) * norm;
if (plm > 0.0) {
mi(i,j) += plm * log(plm/hpm/hpl);
}
}
}
}
}
// Divide by the marginal entropies to normalize MI
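// (`Hl` and `Hm` hold the negated marginal entropies, so `Hl * Hm = H(X) * H(Y) >= 0`)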
mi(i,j) = mi(i,j) / sqrt(Hm * Hl);
// Symmetrize MI
mi(j,i) = mi(i,j);
// Reset `hist2d` to zero in all bins
for (l = 0; l < n_bins; l++) {
for (m = 0; m < n_bins; m++) {
hist2d(l,m) = 0;
}
}
}
// Put ones on the diagonal
else {
mi(i,j) = 1.0;
}
}
}
"""
# C++ code to compute MI
code_mi = r"""
int i, j, k, l, m;
int symbol, symbol_i, symbol_j;
double norm, rescaled, hpl, hpm, plm;
// Calculate histogram norm
norm = 1.0 / n_samples;
for (i = 0; i < N; i++) {
for (k = 0; k < n_samples; k++) {
// Calculate symbolic trajectories for each time series,
// where the symbols are bins.
rescaled = scaling * (tsdata(i,k) - range_min);
if (rescaled < 1.0) {
symbolic(i,k) = rescaled * n_bins;
}
else {
symbolic(i,k) = n_bins - 1;
}
// Calculate 1d-histograms for single time series
symbol = symbolic(i,k);
hist(i,symbol) += 1;
}
}
for (i = 0; i < N; i++) {
for (j = 0; j <= i; j++) {
// The case i = j is not of interest here!
if (i != j) {
// Calculate 2d-histogram for one pair of time series
// `(i,j)`.
for (k = 0; k < n_samples; k++) {
symbol_i = symbolic(i,k);
symbol_j = symbolic(j,k);
hist2d(symbol_i,symbol_j) += 1;
}
// Calculate mutual information for one pair of time
// series `(i,j)`.
// Hl = 0;
for (l = 0; l < n_bins; l++) {
hpl = hist(i,l) * norm;
if (hpl > 0.0) {
// `Hl += hpl * log(hpl);`
// `Hm = 0;`
for (m = 0; m < n_bins; m++) {
hpm = hist(j,m) * norm;
if (hpm > 0.0) {
// `Hm += hpm * log(hpm);`
plm = hist2d(l,m) * norm;
if (plm > 0.0) {
mi(i,j) += plm * log(plm/hpm/hpl);
}
}
}
}
}
// Symmetrize MI
mi(j,i) = mi(i,j);
// Reset `hist2d` to zero in all bins
for (l = 0; l < n_bins; l++) {
for (m = 0; m < n_bins; m++) {
hist2d(l,m) = 0;
}
}
}
// Put ones on the diagonal
else {
mi(i,j) = 1.0;
}
}
}
"""
# Initialize necessary variables to pass on to C++ code snippets above
vars = ['tsdata', 'n_samples', 'N', 'n_bins', 'scaling', 'range_min',
'symbolic', 'hist', 'hist2d', 'mi']
# Compute normalized or non-normalized mutual information
if (normalized):
weave.inline(code_nmi, vars, type_converters=weave.converters.blitz,
compiler='gcc', extra_compile_args=['-O3'])
else:
weave.inline(code_mi, vars, type_converters=weave.converters.blitz,
compiler='gcc', extra_compile_args=['-O3'])
# Python version of (N)MI computation (slower)
else:
# Define references to NumPy functions for faster function calls
histogram = np.histogram
histogram2d = np.histogram2d
log = np.log
# Get common range for all histograms
range_min = tsdata.min()
range_max = tsdata.max()
# Calculate the histograms for each time series
p = np.zeros((N, n_bins))
for i in xrange(N):
p[i,:] = (histogram(tsdata[:, i], bins=n_bins,
range=(range_min,range_max))[0]).astype("float64")
# Normalize by total number of samples = length of each time series
p /= n_samples
# Make sure that bins with zero estimated probability are not counted
# in the entropy measures.
p[p == 0] = 1
# Compute the information entropies of each time series
H = - (p * log(p)).sum(axis = 1)
# Calculate only the lower half of the MI matrix, since MI is
# symmetric with respect to `X` and `Y`.
for i in xrange(N):
for j in xrange(i):
# Calculate the joint probability distribution
pxy = (histogram2d(tsdata[:,i], tsdata[:,j], bins=n_bins,
range=((range_min, range_max),
(range_min, range_max)))[0]).astype("float64")
# Normalize joint distribution
pxy /= n_samples
# Compute the joint information entropy
pxy[pxy == 0] = 1
HXY = - (pxy * log(pxy)).sum()
# Normalize by entropies (or not)
if (normalized):
mi.itemset((i,j), (H.item(i) + H.item(j) - HXY)/(np.sqrt(H.item(i)*H.item(j))))
else:
mi.itemset((i,j), H.item(i) + H.item(j) - HXY)
# Symmetrize MI
mi.itemset((j,i), mi.item((i,j)))
# Put ones on the diagonal
np.fill_diagonal(mi,1)
# Return (N)MI matrix
return mi
##########################################################################################
[docs]def issym(A,tol=1e-9):
"""
Check for symmetry of a 2d NumPy array
Parameters
----------
A : NumPy 2darray
A presumably symmetric matrix
tol : float
Tolerance :math:`\\tau` for checking if :math:`A` is sufficiently close to :math:`A^\\top`.
Returns
-------
is_sym : bool
True if :math:`A` satisfies :math:`|A - A^\\top| \\leq \\tau |A|`,
where :math:`|\\cdot|` denotes the Frobenius norm. Thus, if this inequality
holds, :math:`A` is approximately symmetric.
Notes
-----
For further details regarding the Frobenius norm approach used, please refer to the
discussion in `this <http://www.mathworks.com/matlabcentral/newsreader/view_thread/252727>`_
thread at MATLAB central
See also
--------
isclose : An absolute-value based comparison readily provided by NumPy.
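Examples
--------
>>> A = np.array([[1.,2.],[2.,3.]])
>>> issym(A)
True
>>> issym(A + np.triu(np.ones((2,2)),1))
False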
"""
# Check if Frobenius norm of `A - A.T` is sufficiently small (respecting round-off errors)
try:
is_sym = (norm(A-A.T,ord='fro') <= tol*norm(A,ord='fro'))
except:
raise TypeError('Input argument has to be a square matrix/array and a scalar tol (optional)!')
return is_sym
##########################################################################################
[docs]def printdata(data,leadrow,leadcol,fname=None):
"""
Pretty-print/-save array-like data
Parameters
----------
data : NumPy 2darray
An `M`-by-`N` array of data
leadrow : Python list or NumPy 1darray
List/array of length `N` or `N+1` providing labels to be printed in the first row of the table
(strings/numerals or both). See Examples for details
leadcol : Python list or NumPy 1darray
List/array of length `M` providing labels to be printed in the first column of the table
(strings/numerals or both)
fname : str
Name of a csv-file (with or without extension `.csv`) used to save the table
(WARNING: existing files will be overwritten!). Can also be a path + filename
(e.g., `fname='path/to/file.csv'`). By default output is not saved.
Returns
-------
Nothing : None
Notes
-----
Uses the `texttable` module to print results
Examples
--------
>>> import numpy as np
>>> data = np.random.rand(2,3)
>>> row1 = ["a","b",3]
>>> col1 = np.arange(2)
>>> printdata(data,row1,col1)
+--------------------+--------------------+--------------------+--------------------+
| | a | b | 3 |
+====================+====================+====================+====================+
| 0 | 0.994018537964 | 0.707532139166 | 0.767497407803 |
+--------------------+--------------------+--------------------+--------------------+
| 1 | 0.914193045048 | 0.758181936461 | 0.216752553325 |
+--------------------+--------------------+--------------------+--------------------+
>>> row1 = ["labels"] + row1
>>> printdata(data,row1,col1,fname='dummy')
+--------------------+--------------------+--------------------+--------------------+
| labels | a | b | 3 |
+====================+====================+====================+====================+
| 0 | 0.994018537964 | 0.707532139166 | 0.767497407803 |
+--------------------+--------------------+--------------------+--------------------+
| 1 | 0.914193045048 | 0.758181936461 | 0.216752553325 |
+--------------------+--------------------+--------------------+--------------------+
>>> cat dummy.csv
labels, a, b, 3
0,0.994018537964,0.707532139166,0.767497407803
1,0.914193045048,0.758181936461,0.216752553325
See also
--------
texttable : a module for creating simple ASCII tables (currently available at the
`Python Package Index <https://pypi.python.org/pypi/texttable/0.8.1>`_)
"""
# Try to import Texttable object
try: from texttable import Texttable
except:
raise ImportError("Could not import texttable! Consider installing it using pip install texttable")
# Check dimensions of input
try:
ds = data.shape
except:
raise TypeError('Input must be an M-by-N NumPy array, not ' + type(data).__name__+'!')
if len(ds) > 2:
raise ValueError('Input must be an M-by-N NumPy array!')
for lvar in [leadcol,leadrow]:
if not isinstance(lvar,(list,np.ndarray)):
raise TypeError("The inputs `leadcol` and `leadrow` must by Python lists or Numpy 1d arrays!")
if len(np.array(lvar).squeeze().shape) != 1:
raise ValueError("The inputs `leadcol` and `leadrow` must 1-d lists/arrays!")
m = len(leadcol)
n = len(leadrow)
# If a filename was provided make sure it's a string and check if the path exists
if fname is not None:
if not isinstance(fname,(str,unicode)):
raise TypeError('Optional output filename has to be a string!')
fname = str(fname)
if fname.find("~") == 0:
fname = os.path.expanduser('~') + fname[1:]
slash = fname.rfind(os.sep)
if slash >= 0 and not os.path.isdir(fname[:fname.rfind(os.sep)]):
raise ValueError('Invalid path for output file: '+fname+'!')
if fname[-4::] != '.csv':
fname = fname + '.csv'
save = True
else: save = False
# Get dimension of data and corresponding leading column/row
if len(ds) == 1:
K = ds[0]
if K == m:
N = 1; M = K
elif K == n or K == (n-1):
M = 1; N = K
else:
raise ValueError('Number of elements in heading column/row and data do not match up!')
data = data.reshape((M,N))
else:
M,N = ds
if M != m:
raise ValueError('Number of rows and no. of elements in leading column do not match up!')
elif N == n:
head = [' '] + list(leadrow)
elif N == (n-1):
head = list(leadrow)
else:
raise ValueError('Number of columns and no. of elements in head row do not match up!')
# Assemble the full data array including the leading column
Data = np.column_stack((leadcol,data.astype('str')))
# Initialize table object and fill it with stuff
table = Texttable()
table.set_cols_align(["l"]*(N+1))
table.set_cols_valign(["c"]*(N+1))
table.set_cols_dtype(["t"]*(N+1))
table.set_cols_width([18]*(N+1))
table.add_rows([head],header=True)
table.add_rows(Data.tolist(),header=False)
# Pump out table
print table.draw() + "\n"
# If wanted, save stuff in a csv file
if save:
np.savetxt(fname,Data,delimiter=",",fmt="%s",header="".join(str(hd)+", " for hd in head)[:-2],comments="")
return
##########################################################################################
[docs]def img2vid(imgpth,imgfmt,outfile,fps,filesize=None,ext='mp4',preset='veryslow'):
"""
Convert a sequence of image files to a video using ffmpeg
Parameters
----------
imgpth : str
Path to image files
imgfmt : str
Format specifier for images. All files in the image stack have to follow the same naming
convention, e.g., given the sequence `im_01.png`, `im_02.png`, ...,`im_99.png` the correct
format specifier `imgfmt` is 'im_%02d.png'
outfile : str
Filename (including path if not in current directory) for output video. If an extension
is provided, e.g., 'animation.mp4', it is passed on to the x264 video encoder in
ffmpeg to set the video format of the output. Use `ffmpeg -formats` in a shell
to get a list of supported formats (any format labeled 'Muxing supported').
fps : int
Framerate of the video (number of frames per second)
filesize : float
Target size of video file in MB (Megabytes).
If provided, an encoding bitrate will be chosen such that the target size
`filesize` is not exceeded. If `filesize = None` the default constant rate factor of ffmpeg
is used (the longer the movie, the larger the generated file).
ext : str
Extension of the video-file. If `outfile` does not have a filetype extension, then
the default value of `ext` is used and an mp4 video is generated. Note: if `outfile`
has an extension, then any value of `ext` will be ignored. Use `ffmpeg -formats` in a shell
to get a list of supported formats (any format labeled 'Muxing supported').
preset : str
Video quality options for ffmpeg's x264 encoder controlling the encoding speed to
compression ratio. A slower preset results in better compression (higher quality
per filesize) but longer encoding time. Available presets in ffmpeg are
'ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow', and 'placebo'.
Returns
-------
Nothing : None
Examples
--------
Suppose the 600 sequentially numbered tiff files `image_001.tiff`, `image_002.tiff`, ..., `image_600.tiff`
located in the directory `/path/to/images/` have to be converted to a QuickTime movie (mov file) of no more than 25MB.
We want the video to show 6 consecutive images per second (i.e., a framerate of 6 frames per second).
This can be done using the following command:
>>> img2vid('/path/to/images','image_%03d.tiff','image_1_600',6,filesize=25,ext='mov',preset='veryslow')
Alternatively,
>>> img2vid('/path/to/images','image_%03d.tiff','image_1_600_loq.mov',6,filesize=25,ext='mkv',preset='ultrafast')
also generates an mov video of 25MB. The encoding will be faster but the image quality of `image_1_600_loq.mov`
will be lower compared to `image_1_600.mov` generated by the first call. Note that the optional keyword argument
`ext='mkv'` is ignored since the provided output filename 'image_1_600_loq.mov' already contains an extension.
"""
# First and foremost, check if ffmpeg is available, otherwise everything else is irrelevant
if os.system("which ffmpeg > /dev/null") != 0:
msg = "Could not find ffmpeg. It seems like ffmpeg is either not installed or not in the search path. "
raise ValueError(msg)
# Check if image directory exists and append trailing slash if necessary
if not isinstance(imgpth,(str,unicode)):
raise TypeError('Path to image directory has to be a string!')
imgpth = str(imgpth)
if imgpth.find("~") == 0:
imgpth = os.path.expanduser('~') + imgpth[1:]
if not os.path.isdir(imgpth):
raise ValueError('Invalid path to image directory: '+imgpth+'!')
slash = imgpth.rfind(os.sep)
if slash != len(imgpth)-1:
imgpth += os.sep
# Check if `imgfmt` is a valid string format specifier
# (don't use split below in case we have something like im.001.tiff)
if not isinstance(imgfmt,(str,unicode)):
raise TypeError('Format specifier for images has to be a string!')
imgfmt = str(imgfmt)
dot = imgfmt.rfind('.')
fmt = imgfmt[:dot]
imtype = imgfmt[dot+1:]
if fmt.find('%') < 0: raise ValueError('Invalid image format specifier: `'+fmt+'`!')
# Check if image directory actually contains any images of the given type
imgs = natsort.natsorted(myglob(imgpth,'*.'+imtype), key=lambda y: y.lower())
num_imgs = len(imgs)
if num_imgs < 2: raise ValueError('Directory '+imgpth+' contains fewer than 2 `'+imtype+'` files!')
# Check validity of `outfile`
if not isinstance(outfile,(str,unicode)):
raise TypeError('Output filename has to be a string!')
outfile = str(outfile)
if outfile.find("~") == 0:
outfile = os.path.expanduser('~') + outfile[1:]
slash = outfile.rfind(os.sep)
if slash >= 0 and not os.path.isdir(outfile[:outfile.rfind(os.sep)]):
raise ValueError('Invalid path to save movie: '+outfile+'!')
# Check format specifier for the movie: the if loop separates filename from extension
# (use split here to prevent the user from creating abominations like `my.movie.mp4`)
dot = outfile.rfind('.')
if dot == 0: raise ValueError(outfile+' is not a valid filename!') # e.g., outfile = '.name'
if dot == len(outfile) - 1: # e.g., outfile = 'name.'
outfile = outfile[:dot]
dot = -1
if dot > 0: # e.g., outfile = 'name.mp4'
out_split = outfile.split('.')
if len(out_split) > 2: raise ValueError(outfile+' is not a valid filename!')
outfile = out_split[0]
# If outfile had an extension but there was an add'l extension provided, warn the user
if out_split[1] != str(ext) and str(ext) != 'mp4':
print "WARNING: Using extension `"+out_split[1]+"` of output filename, not `"+str(ext)+"`!"
ext = out_split[1]
else: # e.g., outfile = 'name'
if str(ext) != ext:
raise TypeError('Filename extension for movie has to be a string!')
exl = str(ext).split('.')
if len(exl) > 1: raise ValueError(ext+' is not a valid extension for a video file!')
ext = exl[0]
# Make sure `fps` is a positive integer
scalarcheck(fps,'fps',kind='int',bounds=[1,np.inf])
# Check if output filesize makes sense (if provided)
if filesize is not None:
scalarcheck(filesize,'filesize',bounds=[0,np.inf])
# Check if `preset` is valid (if provided)
if not isinstance(preset,(str,unicode)):
raise TypeError('Preset specifier for video encoding has to be a string!')
supported = ['ultrafast','superfast','veryfast','faster','fast','medium','slow','slower','veryslow','placebo']
if supported.count(preset) == 0:
msg = 'Preset `'+preset+'` not supported by ffmpeg. Supported options are: '+\
''.join(supp+', ' for supp in supported)[:-2]
raise ValueError(msg)
# Now let's start to actually do something and set the null device based on which platform we're running on
if os.name == "nt":   # `os.uname()` does not exist on Windows
nulldev = 'NUL'
else:
nulldev = '/dev/null'
# Encode movie respecting provided file-size limit
if filesize is not None:
# Calculate movie length based on desired frame-rate and bit-rate such that given filesize is not exceeded (MB->kbit/s uses 8192)
movie_len = np.ceil(num_imgs/fps)
brate = int(np.floor(filesize*8192/movie_len))
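# (e.g., a 25MB target for 600 frames at 6 fps, i.e., a 100 second movie,
# yields a bitrate of 25*8192/100 = 2048 kbit/s)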
# Use two-pass encoding to ensure maximum image quality while keeping the filesize within specified bounds
os.system("ffmpeg -y -framerate "+str(fps)+" -f image2 -i "+imgpth+imgfmt+" "+\
"-vf 'scale=trunc(iw/2)*2:trunc(ih/2)*2' "
"-vcodec libx264 -preset "+preset+" -pix_fmt yuv420p -b:v "+str(brate)+"k -b:a 0k -pass 1 "+\
"-f "+ext+" "+nulldev+" && "+\
"ffmpeg -framerate "+str(fps)+" -f image2 -i "+imgpth+imgfmt+" "+\
"-vcodec libx264 -preset "+preset+" -pix_fmt yuv420p -b:v "+str(brate)+"k -b:a 0k -pass 2 "+\
"-vf 'scale=trunc(iw/2)*2:trunc(ih/2)*2' "+outfile+"."+ext)
# Encode movie with no regards given to final size
else:
# Use a constant rate factor (incompatible with 2-pass encoding) to render the movie
os.system("ffmpeg -framerate "+str(fps)+" -f image2 -i "+imgpth+imgfmt+" "+\
"-vcodec libx264 -preset "+preset+" -pix_fmt yuv420p "+\
"-vf 'scale=trunc(iw/2)*2:trunc(ih/2)*2' "+outfile+"."+ext)
return
##########################################################################################
def build_hive(ax,branches,connections,node_vals=None,center=(0,0),branch_extent=None,positions=None,labels=None,
angle=90,branch_colors=None,branch_alpha=1.0,node_cmap=plt.cm.jet,node_alpha=1.0,\
edge_cmap=plt.cm.jet,edge_alpha=1.0,edge_vrange=[0,1],node_vrange=[0,1],node_sizes=0.01,\
branch_lw=2,edge_lw=0.5,radians=0.15,labelsize=8,node_lw=0.5,nodes3d=False,sphere_res=40,\
lightsource=None,full3d=False,viewpoint=None,ethresh=None,show_grid=False):
"""
Draw a hive-plot-like rendering of a network with nodes arranged on branches (2D or 3D).
By default no threshold is applied to edges, i.e., even edges with zero-weights are drawn using
the respective value from the colormap `edge_cmap`. If you want to remove zero-weight edges use
the keyword argument `ethresh = 0`.
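A minimal usage sketch (hypothetical data: three branches of two nodes each with
random symmetric edge weights; only edges with weights greater than 0.5 are drawn):
>>> fig = plt.figure()
>>> ax = fig.add_subplot(111)
>>> branches = {0: [0,1], 1: [2,3], 2: [4,5]}
>>> cons = np.random.rand(6,6)
>>> cons = (cons + cons.T)/2    # symmetric weights -> network is treated as undirected
>>> build_hive(ax,branches,cons,ethresh=0.5)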
"""
# Define some default values in case the user didn't provide all optional inputs
branch_beg = 0.05 # Start of branches as displacement from `center` (if `branch_extent == None`)
branch_end = 0.95 # Length of branches (if `branch_extent == None`)
pos_offset = 0.05 # Offset percentage for nodes on branches (if `positions == None`)
x_offset = 0.1 # Offset percentage for x-axis limits
y_offset = 0.1 # Offset percentage for y-axis limits
z_offset = 0.1 # Offset percentage for z-axis limits (only relevant if `full3d == True`)
# Error checking for dictionaries with numeric values
def check_dict(dct,name):
try:
for branch in dct.keys():
dct[branch] = np.array(dct[branch])
except:
raise TypeError('The provided '+name+' have to be a dictionary with the same keys as `branches`!')
for branch, nodes in dct.items():
arrcheck(nodes,'vector',name)
if branches[branch].size != nodes.size:
raise ValueError("Provided branches and "+name+" don't match up!")
# Amend `FancyArrowPatch` by 3D capabilities
# (taken from http://stackoverflow.com/questions/11140163/python-matplotlib-plotting-a-3d-cube-a-sphere-and-a-vector)
class Arrow3D(FancyArrowPatch):
def __init__(self, xs, ys, zs, *args, **kwargs):
FancyArrowPatch.__init__(self, (0,0), (0,0), *args, **kwargs)
self._verts3d = xs, ys, zs
def draw(self, renderer):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))
FancyArrowPatch.draw(self, renderer)
# Check if `ax` is really an mpl axis object
try:
plt.sca(ax)
except:
raise TypeError("Could not make axis "+str(ax)+" active!")
# See if `branches` is a dictionary of branch numbers/labels holding node-numbers
if not isinstance(branches,dict):
raise TypeError('The input `branches` has to be dictionary-like, not '+type(branches).__name__)
try:
for branch in branches.keys():
branches[branch] = np.array(branches[branch]).squeeze()
except:
raise TypeError('Branches must be provided as dictionary of node arrays/lists!!')
for nodes in branches.values():
arrcheck(nodes,'vector','node indices')
branch_arr = np.array(branches.keys())
num_branches = branch_arr.size
if num_branches == 1:
raise ValueError('Only one branch found - at least two branches are required!')
if type(branches).__name__ != 'OrderedDict':
branch_arr = np.sort(branch_arr) # if we have a regular dict, sort its keys
node_arr = []
for nodes in branches.values():
node_arr += list(nodes)
node_arr = np.unique(node_arr)
num_nodes = node_arr.size
if np.any(np.diff(node_arr) != 1) or node_arr.min() != 0:
raise ValueError('Node numbers have to be contiguous in ascending order starting with 0!')
for br in branches.keys():
branches[br] = np.array(branches[br])
# See if `connections` is a 2d array that matches the provided branch dictionary
arrcheck(connections,'matrix','connections',bounds=[0,1])
if connections.shape[0] != num_nodes:
raise ValueError('Number of nodes does not match connection matrix!')
# Let's see if we're going to have fun in 3D
if not isinstance(nodes3d,bool):
raise TypeError('Three-dimensional nodes are activated using a binary True/False flag!')
if not isinstance(full3d,bool):
raise TypeError('Full 3D plotting is activated using a binary True/False flag!')
if full3d:
nodes3d = False # just internally: turn off this switch to avoid confusion later on
if not isinstance(show_grid,bool):
raise TypeError('Grid is drawn or not based on a binary True/False flag!')
if show_grid and not full3d:
print "WARNING: 3D grid is only shown for full 3D plots!"
# Now check resolution parameter for rendering spheres
scalarcheck(sphere_res,'sphere_res',kind='int',bounds=[2,np.inf])
if sphere_res >= 100:
print "WARNING: The resolution parameter for nodal spheres is very large - rendering might take forever..."
# See if a light-source for illumination was provided
if lightsource is not None:
if isinstance(lightsource,bool):
if lightsource == True:
lightsource = np.array([90,45])
else:
lightsource = None
if lightsource is not None:
lightsource = np.array(lightsource)
arrcheck(lightsource,'vector','lightsource')
if lightsource.size != 2:
raise ValueError("Light-source has to be provided as azimuth/altitude degrees!")
if lightsource.min() < 0 or lightsource[0] > 360 or lightsource[1] > 90:
raise ValueError("Light-source azimuth/elevation has to be between 0-360 and 0-90 degrees, respectively!")
# See if a threshold for drawing edges was provided, if not, don't use one
if ethresh is not None:
scalarcheck(ethresh,'ethresh',bounds=[0,1])
# See if a camera position (in azimuth/elevation degrees) was provided, if not use some defaults
if viewpoint is not None:
viewpoint = np.array(viewpoint)
if np.issubdtype(viewpoint.dtype, np.number):
arrcheck(viewpoint,'vector','viewpoint')
if viewpoint.size != 2:
raise ValueError("View-point has to be provided as azimuth/altitude degrees!")
else:
raise TypeError("View-point for illumination has to [`azdeg`,`altdeg`]!")
else:
viewpoint = np.array([-60,30])
# See if nodal values were provided, if not create simple dict
if node_vals is not None:
check_dict(node_vals,'nodal values')
for vals in node_vals.values():
if vals.min() < 0 or vals.max() > 1:
raise ValueError('Nodal values must be between zero and one!')
else:
node_vals = {}
for branch, nodes in branches.items():
node_vals[branch] = np.ones(branches[branch].shape)
# See if center makes sense, if provided
try:
center = np.array(center)
except:
raise TypeError('Unsupported type for input `center`: '+type(center).__name__)
arrcheck(center,'vector','center')
if np.all(center == 0):
if full3d:
center = np.zeros((3,))
else:
if full3d == False and center.size != 2:
raise ValueError("Center coordinates have to be two-dimensional!")
if full3d == True and center.size != 3:
raise ValueError("For 3D plots center coordinates have to be three-dimensional!")
# See if branch lengths were provided, otherwise construct'em
if branch_extent is not None:
try:
for branch in branches.keys():
branch_extent[branch] = np.array(branch_extent[branch])
except:
raise TypeError("The provided branch dimensions have to be a dictionary with the same keys as `branches`!")
for branch in branches.keys():
arrcheck(branch_extent[branch],'vector','branch dimensions')
if branch_extent[branch].size != 2:
raise ValueError("Only two values by branch supported for branch dimensions!")
if branch_extent[branch][0] >= branch_extent[branch][1]:
raise ValueError("Branch dimensions have to be increasing (beginning -> end)!")
else:
branch_extent = {}
for branch in branches.keys():
branch_extent[branch] = [branch_beg,branch_end]
# See if nodal positions were provided, if not create simple dict
if positions is not None:
check_dict(positions,'nodal positions')
for branch in branches.keys():
if positions[branch].min() < branch_extent[branch][0] or positions[branch].max() > branch_extent[branch][1]:
raise ValueError('Nodal positions on branches must be within branch extent!')
else:
positions = {}
for branch, extent in branch_extent.items():
length = extent[1] - extent[0]
offset = pos_offset*length
positions[branch] = np.linspace(extent[0]+offset,extent[1]-offset,branches[branch].size)
# See if labels were provided and make sense, otherwise don't use labels
if labels is not None:
try:
for branch in branches.keys():
labels[branch] = np.array(labels[branch])
except:
raise TypeError("The provided nodal labels have to be a dictionary with the same keys as `branches`!")
for branch in branches.keys():
if branches[branch].size != labels[branch].size:
raise ValueError("Provided branches and nodal labels don't match up!")
if np.issubdtype(labels[branch].dtype, np.number):
raise ValueError("The provided nodal labels must be strings!")
if full3d or nodes3d:
print "WARNING: Due to limiations in mplot3d the positiong of text in 3d space is somewhat screwed up..."
# Now make sure label font-size makes sense
scalarcheck(labelsize,'labelsize',bounds=[0,np.inf])
# Check if branch angle(s) were provided, if not, generate'em
if isinstance(angle,dict):
try:
for branch in branches.keys():
tmp = np.array(angle[branch])
except:
raise TypeError("The provided branch angles have to be a dictionary with the same keys as `branches`!")
for branch in branches.keys():
if full3d:
angle[branch] = np.array(angle[branch])
arrcheck(angle[branch],'vector','3D branch angles')
if len(angle[branch]) != 2:
raise ValueError("3D branch angles must be provided as two values per branch!")
if angle[branch][0] < 0 or angle[branch][0] > 360:
raise ValueError("Azimuth must be between 0 and 360 degrees!")
if angle[branch][1] < -90 or angle[branch][1] > 90:
raise ValueError("Elevation must be between -90 and +90 degrees!")
azim = math.radians(angle[branch][0])
elev = math.radians(angle[branch][1])
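# Convert elevation to a polar angle measured from the positive z-axis (elev = 0 -> pi/2)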
elev = np.pi/2 - (elev > 0)*elev + (elev < 0)*np.abs(elev)
angle[branch] = np.array([azim,elev])
else:
if not np.isscalar(angle[branch]) or not np.isreal(angle[branch]):
raise TypeError("Branch angles must be real-valued, one value per branch!")
if not np.isfinite(angle[branch]):
raise ValueError("Branch angles must not be NaN or Inf!")
if angle[branch] < 0 or angle[branch] > 360:
raise ValueError("Branch angles must be between 0 and 360 degrees!")
angle[branch] = math.radians(angle[branch])
elif np.isscalar(angle):
scalarcheck(angle,'angle',bounds=[0,360])
if full3d:
angle = {}
angle[branch_arr[0]] = np.zeros((2,)) # in spherical coordinates (main branch is vertical line from origin)
start = 1/4*np.pi
degs = np.linspace(start,start+2*np.pi,num_branches) # these are the "azimuth" angles (well, not really...)
elev = 3/4*np.pi
for br, branch in enumerate(branch_arr[1:]): # Here order is important! Use the generated (sorted) array!
angle[branch] = np.array([degs[br],elev])
else:
angle = math.radians(angle)
degs = np.linspace(angle,angle+2*np.pi,num_branches+1)
angle = {}
for br, branch in enumerate(branch_arr): # Here order is important! Use the generated (sorted) array!
angle[branch] = degs[br]
else:
raise TypeError("Branch angles have to be provided either as scalar or dictionary!")
# Check color-values of branches - if not provided, construct'em
if branch_colors is not None:
if isinstance(branch_colors,dict):
for branch in branches.keys():
if not isinstance(branch_colors[branch],(str,unicode)):
raise ValueError("The provided branch colors must be strings (one color per branch)!")
elif isinstance(branch_colors,str):
bc = branch_colors
branch_colors = {}
for branch in branches.keys():
branch_colors[branch] = bc
else:
raise TypeError("The provided branch colors have to be either a string or "+\
"a dictionary with the same keys as `branches`!")
else:
branch_colors = {}
for branch in branches.keys():
branch_colors[branch] = 'Black'
# Check node and edge color maps
for cmap in [node_cmap,edge_cmap]:
if type(cmap).__name__.find('Colormap') < 0:
raise TypeError("Node/Edge colormaps have to be matplotlib colormaps!")
# Check value ranges for nodes and edges
for vrange in [node_vrange,edge_vrange]:
try:
vrange = np.array(vrange)
except:
raise TypeError('Unsupported type for node/edge value ranges: '+type(vrange).__name__)
arrcheck(vrange,'vector','node/edge value range',bounds=[0,1])
if vrange.size != 2:
raise ValueError("Node/Edge value range has to be two-dimensional!")
if vrange[0] >= vrange[1]:
raise ValueError("Node/Edge value range must strictly increasing!")
# See if nodal sizes have been provided, if not construct dictionary
if isinstance(node_sizes,dict):
check_dict(node_sizes,'nodal sizes')
for vals in node_sizes.values():
if vals.min() < 0:
raise ValueError("Nodal sizes have to be non-negative!")
elif np.isscalar(node_sizes):
scalarcheck(node_sizes,'node_sizes',bounds=[0,np.inf])
ns = node_sizes
node_sizes = {}
for branch,nodes in branches.items():
node_sizes[branch] = ns*np.ones(nodes.shape)
else:
raise TypeError("Nodal sizes have to be provided either as scalar or dictionary!")
# See if nodal alpha values have been provided, if not construct dictionary
if isinstance(node_alpha,dict):
check_dict(node_alpha,'nodal alpha values')
for vals in node_alpha.values():
if vals.min() < 0 or vals.max() > 1:
raise ValueError("Nodal alpha values have to be between zero and one!")
elif np.isscalar(node_alpha):
scalarcheck(node_alpha,'node_alpha',bounds=[0,1])
ns = node_alpha
node_alpha = {}
for branch,nodes in branches.items():
node_alpha[branch] = ns*np.ones(nodes.shape)
else:
raise TypeError("Nodal alpha values have to be provided either as scalar or dictionary!")
# Now make sure node line-width makes sense
scalarcheck(node_lw,'node_lw',bounds=[0,np.inf])
if full3d:
print "WARNING: Line-width specifications for nodes is ignored for full 3D plots!"
# Check if line-widths for branches have been provided, otherwise assign default values
if isinstance(branch_lw,dict):
try:
for branch in branches.keys():
tmp = np.array(branch_lw[branch])
except:
raise TypeError("The provided branch line-widths have to be a dictionary with the same keys as `branches`!")
for branch in branches.keys():
if not np.isscalar(branch_lw[branch]) or not np.isreal(branch_lw[branch]):
raise ValueError("Branch line-widths must be real-valued, one value per branch!")
if not np.isfinite(branch_lw[branch]):
raise ValueError("Branch line-widths must not be NaN or Inf!")
if branch_lw[branch] < 0:
raise ValueError("Branch line-widths have to be non-negative!")
elif np.isscalar(branch_lw):
scalarcheck(branch_lw,'branch_lw',bounds=[0,np.inf])
bw = branch_lw
branch_lw = {}
for branch in branches.keys():
branch_lw[branch] = bw
else:
raise TypeError("Branch line-widths have to be provided either as scalar or dictionary!")
# Check if alpha-values for branches have been provided, otherwise assign default values
if isinstance(branch_alpha,dict):
try:
for branch in branches.keys():
tmp = np.array(branch_alpha[branch])
except:
raise TypeError("The provided branch alpha values have to be a dictionary with the same keys as `branches`!")
for branch in branches.keys():
if not np.isscalar(branch_alpha[branch]) or not np.isreal(branch_alpha[branch]):
raise ValueError("Branch alpha values must be real-valued, one value per branch!")
if not np.isfinite(branch_alpha[branch]):
raise ValueError("Branch alpha values must not be NaN or Inf!")
if branch_alpha[branch] < 0 or branch_alpha[branch] > 1:
raise ValueError("Branch alpha values must be between zero and one!")
elif np.isscalar(branch_alpha):
scalarcheck(branch_alpha,'branch_alpha',bounds=[0,1])
bw = branch_alpha
branch_alpha = {}
for branch in branches.keys():
branch_alpha[branch] = bw
else:
raise TypeError("Branch alpha values have to be provided either as scalar or dictionary!")
# Check if line-widths for edges have been provided, otherwise assign default values
if np.isscalar(edge_lw):
scalarcheck(edge_lw,'edge_lw',bounds=[0,np.inf])
edge_lw = np.ones(connections.shape) * edge_lw
else:
arrcheck(edge_lw,'matrix','edge_lw',bounds=[0,np.inf])
if edge_lw.shape != connections.shape:
raise ValueError("Edge line-widths have to be provided in the same format as connection array!")
# Check if alpha values for edges have been provided, otherwise assign default values
if np.isscalar(edge_alpha):
scalarcheck(edge_alpha,'edge_alpha',bounds=[0,1])
edge_alpha = np.ones(connections.shape) * edge_alpha
else:
arrcheck(edge_alpha,'matrix','edge_alpha',bounds=[0,1])
if np.any(edge_alpha.shape != connections.shape):
raise ValueError("Edge alpha values have to be provided in the same format as connection array!")
# Check if an initial setting for the arch radian was provided, otherwise use the default
if np.isscalar(radians):
scalarcheck(radians,'radians')
else:
arrcheck(radians,'matrix','radians')
if radians.shape[0] != num_branches:
raise ValueError("Arch radians must be provided as a square array matching the no. of branches!")
# Prepare axis
ax.set_aspect('equal')
ax.hold(True)
# If nodes have to be rendered as spheres, some tuning is required...
if nodes3d or full3d:
# Turn on 3d projection if nodes are to be rendered as spheres
bgcol = ax.get_axis_bgcolor()
ax = plt.gca(projection='3d',axisbg=bgcol)
ax.hold(True)
if not full3d:
ax.view_init(azim=-90,elev=90)
else:
ax.view_init(azim=viewpoint[0],elev=viewpoint[1])
# Turn off 3D grid and change background of panes (or not)
if not show_grid:
ax.grid(False)
ax.w_xaxis.set_pane_color(colorConverter.to_rgb(bgcol))
ax.w_yaxis.set_pane_color(colorConverter.to_rgb(bgcol))
ax.w_zaxis.set_pane_color(colorConverter.to_rgb(bgcol))
# Turn off all axes highlights
ax.zaxis.line.set_lw(0)
ax.set_zticks([])
ax.xaxis.line.set_lw(0)
ax.set_xticks([])
ax.yaxis.line.set_lw(0)
ax.set_yticks([])
# Generate surface data for the prototype nodal sphere
theta = np.arange(-sphere_res,sphere_res+1,2)/sphere_res*np.pi
phi = np.arange(-sphere_res,sphere_res+1,2)/sphere_res*np.pi/2
cosphi = np.cos(phi); cosphi[0] = 0; cosphi[-1] = 0
sinth = np.sin(theta); sinth[0] = 0; sinth[-1] = 0
xsurf = np.outer(cosphi,np.cos(theta))
ysurf = np.outer(cosphi,sinth)
zsurf = np.outer(np.sin(phi),np.ones((sphere_res+1,)))
# If virtual lighting is wanted, create a light source for illumination
if lightsource is not None:
light = LightSource(*lightsource)
rgb_arr = np.ones((zsurf.shape[0],zsurf.shape[1],3))
# Start by truncating color-values based on vrange limits that were provided
if np.any([0,1] != node_vrange):
node_cmap = plt.cm.ScalarMappable(norm=Normalize(node_vrange[0],node_vrange[1]),cmap=node_cmap).to_rgba
if np.any([0,1] != edge_vrange):
edge_cmap = plt.cm.ScalarMappable(norm=Normalize(edge_vrange[0],edge_vrange[1]),cmap=edge_cmap).to_rgba
# Plot branches and construct nodal patches (we do this no matter if we're 3-dimensional or not)
node_patches = {}
branch_dvecs = {}
branch_kwargs = {'lw':-1, 'color': -np.ones((3,)), 'alpha': -1, 'zorder':1}
for branch in branch_arr:
# Compute normed directional vector of branch
if full3d:
azim,elev = angle[branch]
bdry = branch_extent[branch][1]*np.array([np.sin(elev)*np.cos(azim),np.sin(elev)*np.sin(azim),np.cos(elev)])
else:
bdry = branch_extent[branch][1]*np.array([np.cos(angle[branch]),np.sin(angle[branch])])
vec = bdry - center
vec /= np.linalg.norm(vec)
bstart = center + branch_extent[branch][0]*vec
branch_dvecs[branch] = vec
# Plot branch as straight line
branch_kwargs['lw'] = branch_lw[branch]
branch_kwargs['color'] = branch_colors[branch]
branch_kwargs['alpha'] = branch_alpha[branch]
if full3d:
plt.plot([bstart[0],bdry[0]],[bstart[1],bdry[1]],zs=[bstart[2],bdry[2]],zdir='z',**branch_kwargs)
elif nodes3d:
plt.plot([bstart[0],bdry[0]],[bstart[1],bdry[1]],zs=0,zdir='z',**branch_kwargs)
else:
plt.plot([bstart[0],bdry[0]],[bstart[1],bdry[1]],**branch_kwargs)
# Now construct circular patches for all nodes and save'em in the `patch_list` list (and the `node_patch` dict)
patch_list = []
for node in xrange(branches[branch].size):
pos = center + vec*positions[branch][node]
patch_list.append(Circle(pos,radius=node_sizes[branch][node],\
facecolor=node_cmap(node_vals[branch][node]),\
alpha=node_alpha[branch][node],\
lw=node_lw,\
zorder=3))
node_patches[branch] = patch_list
# Determine if our network is directed or not
sym = issym(connections)
# Allocate dictionary for all edge-related parameters
edge_kwargs = {'connectionstyle':'a string','lw': -1, 'alpha': -1, 'color': -np.ones((3,)), 'zorder': 2}
if sym:
edge_kwargs['arrowstyle'] = '-'
else:
edge_kwargs['arrowstyle'] = '-|>'
# 3D is again the special snowflake, so do this nonsense separately...
if full3d:
# In a fully three-dimensional environment, we can't go 'round the tree to plot edges - everything may be connected
seen = []
for br, branch in enumerate(branch_arr):
seen.append(branch)
neighbors = np.setdiff1d(branch_arr,seen)
for twig in neighbors:
if np.isscalar(radians):
br_vec = branch_dvecs[branch]
tw_vec = branch_dvecs[twig]
ang_bt = np.arctan2(np.linalg.norm(np.cross(br_vec,tw_vec)),br_vec.dot(tw_vec))
ang_bt += 2*np.pi*(ang_bt >= 0)
rad = (-1)**(ang_bt > np.pi)*radians
else:
rad = radians[br,np.where(branch_arr==twig)[0][0]]
for n1,node1 in enumerate(branches[branch]):
for n2,node2 in enumerate(branches[twig]):
edge_kwargs['connectionstyle'] = 'arc3,rad=%s'%rad
edge_kwargs['lw'] = edge_lw[node1,node2]
edge_kwargs['alpha'] = edge_alpha[node1,node2]
edge_kwargs['color'] = edge_cmap(connections[node1,node2])
xcoords = [node_patches[branch][n1].center[0],node_patches[twig][n2].center[0]]
ycoords = [node_patches[branch][n1].center[1],node_patches[twig][n2].center[1]]
zcoords = [node_patches[branch][n1].center[2],node_patches[twig][n2].center[2]]
if sym:
if connections[node1,node2] > ethresh:
ax.add_artist(Arrow3D(xcoords,ycoords,zcoords,**edge_kwargs))
else:
if connections[node1,node2] > ethresh:
ax.add_artist(Arrow3D(xcoords,ycoords,zcoords,**edge_kwargs))
if connections[node2,node1] > ethresh:
rad = - rad
edge_kwargs['connectionstyle'] = 'arc3,rad=%s'%rad
ax.add_artist(Arrow3D(xcoords[::-1],ycoords[::-1],zcoords[::-1],**edge_kwargs))
# 2D rendering of edges is a lot easier (just go branch by branch)
else:
for br, branch in enumerate(branch_arr):
if br < branch_arr.size-1:
twig = branch_arr[br+1]
else:
twig = branch_arr[0]
if np.isscalar(radians):
br_vec = branch_dvecs[branch]
tw_vec = branch_dvecs[twig]
ang_bt = np.arctan2(tw_vec[1],tw_vec[0]) - np.arctan2(br_vec[1],br_vec[0])
ang_bt += 2*np.pi*(ang_bt < 0)
rad = (-1)**(ang_bt > np.pi)*radians
else:
rad = radians[br,(br+1) % num_branches]
for n1,node1 in enumerate(branches[branch]):
for n2,node2 in enumerate(branches[twig]):
edge_kwargs['connectionstyle'] = 'arc3,rad=%s'%rad
edge_kwargs['lw'] = edge_lw[node1,node2]
edge_kwargs['alpha'] = edge_alpha[node1,node2]
edge_kwargs['color'] = edge_cmap(connections[node1,node2])
xcoords = [node_patches[branch][n1].center[0],node_patches[twig][n2].center[0]]
ycoords = [node_patches[branch][n1].center[1],node_patches[twig][n2].center[1]]
if sym:
if connections[node1,node2] > ethresh:
if nodes3d:
ax.add_artist(Arrow3D(xcoords,ycoords,[0,0],**edge_kwargs))
else:
ax.add_patch(FancyArrowPatch(node_patches[branch][n1].center,\
node_patches[twig][n2].center,\
**edge_kwargs))
else:
if connections[node1,node2] > ethresh:
if nodes3d:
ax.add_artist(Arrow3D(xcoords,ycoords,[0,0],**edge_kwargs))
else:
ax.add_patch(FancyArrowPatch(node_patches[branch][n1].center,\
node_patches[twig][n2].center,\
**edge_kwargs))
if connections[node2,node1] > ethresh:
rad = - rad
edge_kwargs['connectionstyle'] = 'arc3,rad=%s'%rad
if nodes3d:
ax.add_artist(Arrow3D(xcoords[::-1],ycoords[::-1],[0,0],**edge_kwargs))
else:
ax.add_patch(FancyArrowPatch(node_patches[twig][n2].center,\
node_patches[branch][n1].center,\
**edge_kwargs))
# Finally, draw nodes and compute maximal extent of branches
top = -np.inf
bot = np.inf
lft = np.inf
rgt = -np.inf
up = -np.inf
lo = np.inf
lbl_kwargs = {'fontsize':labelsize,'ha':'center','va':'center'}
nd_kwargs = {'cstride':1,'rstride':1,'linewidth':0,'antialiased':False,'alpha':-1,'zorder':-1}
zcord = 0
for branch in branch_arr:
branch_tvec = branch_extent[branch][1]*branch_dvecs[branch]
top = np.max([top,branch_tvec[1]])
bot = np.min([bot,branch_tvec[1]])
lft = np.min([lft,branch_tvec[0]])
rgt = np.max([rgt,branch_tvec[0]])
if full3d:
up = np.max([up,branch_tvec[2]])
lo = np.min([lo,branch_tvec[2]])
for node in xrange(branches[branch].size):
if nodes3d or full3d:
circ = node_patches[branch][node]
if full3d:
zcord = circ.center[2]
nd_kwargs['alpha'] = circ.get_alpha()
nd_kwargs['zorder'] = circ.get_zorder()
if lightsource is not None:
nd_kwargs['facecolors'] = light.shade_rgb(rgb_arr*np.array(circ.get_facecolor()[:-1]),zsurf)
else:
nd_kwargs['color'] = circ.get_facecolor()
ax.plot_surface(circ.get_radius()*xsurf + circ.center[0],\
circ.get_radius()*ysurf + circ.center[1],\
circ.get_radius()*zsurf + zcord,\
**nd_kwargs)
if labels is not None:
if nodes3d:
lcord = 1.5*circ.get_radius()
else:
lcord = node_patches[branch][node].center[2]
ax.text(node_patches[branch][node].center[0],\
node_patches[branch][node].center[1],\
lcord,\
labels[branch][node],**lbl_kwargs)
else:
ax.add_patch(node_patches[branch][node])
if labels is not None:
ax.text(node_patches[branch][node].center[0],\
node_patches[branch][node].center[1],\
labels[branch][node],**lbl_kwargs)
# Set axes limits based on extent of branches
x_width = rgt - lft
y_height = top - bot
ax.set_xlim(left=lft-x_offset*x_width,right=rgt+x_offset*x_width)
ax.set_ylim(bottom=bot-y_offset*y_height,top=top+y_offset*y_height)
if full3d:
z_len = up - lo
ax.set_zlim(bottom=lo-z_offset*z_len,top=up+z_offset*z_len)
# Draw the beauty and get the hell out of here
plt.draw()
if full3d or nodes3d: plt.axis('equal')
##########################################################################################
def nw_zip(ntw):
"""
Convert the upper triangular portion of a real symmetric matrix to a vector and vice versa
Parameters
----------
ntw : NumPy 1d/2d/3d array
Array representing either (a) a vector (1d array) or (b) a set of vectors (2d array)
holding the upper triangular part of a symmetric matrix in column-wise order or
(c) a matrix (2d array) or (d) a rank 3 tensor comprising a cohort of symmetric
matrices whose upper triangular entries will be extracted.
If `ntw` is a 2d array, it may represent either a symmetric matrix to be compressed
or an array of column vectors to be expanded. Specifically, a `K`-by-`M` array will be
reconstructed to form an `N`-by-`N`-by-`M` array of `M` symmetric `N`-by-`N` matrices,
where `K = N * (N - 1) / 2`. Conversely, if the input is a 3d array, its format is
assumed to be `N`-by-`N`-by-`M`. See Notes below for details.
Returns
-------
nws : NumPy 1d/2d/3d array
Depending on the input, the returned array is either a compressed representation of
the symmetric input matrix/matrices or a full matrix/tensor reconstructed from
the provided upper triangular values.
Notes
-----
Note that this routine does NOT consider diagonal elements, i.e., only off-diagonal
entries are stored/reconstructed. By design, entries are assumed to be ordered
column-wise. Note further that a symmetric `N`-by-`N` matrix contains `K = N * (N - 1) / 2`
upper triangular elements. Thus, if `N = 3` then `K = 3`, so a 3-by-3 array may represent
either a symmetric matrix or a set of three 3-element vectors holding the upper
triangular entries of three different symmetric matrices. In this ambiguous case,
the routine always assumes that the input represents a symmetric matrix and
prints a warning message.
Examples
--------
Consider the symmetric matrix `mat`
>>> mat
array([[ 0., 1., 2., 3.],
[ 1., 0., 4., 5.],
[ 2., 4., 0., 6.],
[ 3., 5., 6., 0.]])
Using `nw_zip` to compress `mat` yields the array `vec`
>>> import nws_tools as nwt
>>> vec = nwt.nw_zip(mat)
>>> vec
array([ 1., 2., 3., 4., 5., 6.])
Now reconstruct `mat` from `vec`
>>> nwt.nw_zip(vec)
array([[ 0., 1., 2., 3.],
[ 1., 0., 4., 5.],
[ 2., 4., 0., 6.],
[ 3., 5., 6., 0.]])
Consider a second symmetric matrix `mat2`
>>> mat2
array([[ 0., 7., 8., 9.],
[ 7., 0., 10., 11.],
[ 8., 10., 0., 12.],
[ 9., 11., 12., 0.]])
Now, collect `mat` and `mat2` in a tensor and use `nw_zip` to compress it
>>> mats = np.zeros((4,4,2))
>>> mats[:,:,0] = mat
>>> mats[:,:,1] = mat2
>>> vecs = nw_zip(mats)
>>> vecs
array([[ 1., 7.],
[ 2., 8.],
[ 3., 9.],
[ 4., 10.],
[ 5., 11.],
[ 6., 12.]])
Uncompressing `vecs` yields `mats` again
>>> (mats == nw_zip(vecs)).min()
True
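As discussed in the Notes above, a symmetric 3-by-3 input is ambiguous (`K == N`);
`nw_zip` then assumes a matrix was provided and emits a warning
>>> nwt.nw_zip(np.array([[0., 1., 2.], [1., 0., 3.], [2., 3., 0.]]))
WARNING: Assuming input array is a symmetric 3-by-3 matrix...
array([ 1., 2., 3.])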
See also
--------
None
"""
# Sanity checks (we can't use `arrcheck` here, because `ntw` can have varying funky dimensions...)
try:
stw = ntw.squeeze().shape
except AttributeError:
raise TypeError('Input network must be a NumPy array, not '+type(ntw).__name__+'!')
if len(stw) > 3:
raise ValueError('Input network must not have more than 3 dimensions!')
if not np.issubdtype(ntw.dtype, np.number) or not np.isreal(ntw).all():
raise TypeError('Input network must be a real-valued NumPy array!')
if not np.isfinite(ntw).all():
raise ValueError('Input network must be a real-valued NumPy array without Infs or NaNs!')
# If we're dealing with a 1d-array, convert it to a `K`-by-1 array to kill two birds with one stone below
if len(stw) == 1:
ntw = ntw.reshape(ntw.size,1)
stw = ntw.shape
# Now let's do what we're here for: everything below assumes `N` is the dimension of the
# considered symmetric matrix and `K` denotes the number of elements in its upper triangular
# portion (excluding the main diagonal), i.e., `K = N * (N-1) / 2`
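# (for instance, the 4-by-4 example matrices in the docstring above give `N = 4` and `K = 6`)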
# Start with the ambiguous case: input can be a single vector, an array of vectors or a connectivity matrix
if len(stw) == 2:
# Input is square so it's a symmetric matrix. It's only ambiguous in case `N = 3` (then `K == N`),
# but why would you go through all this trouble for a 3-by-3 matrix???
if stw[0] == stw[1]:
if issym(ntw):
N = stw[0]
if N == 3:
print "WARNING: Assuming input array is a symmetric 3-by-3 matrix..."
msk = np.triu(np.ones((N,N),dtype=bool),1)
return ntw[msk]
else:
raise ValueError("Input matrix is not symmetric and thus cannot be compressed!")
# Input is an array of `M` vectors of length `K` each holding the entries of a symmetric `N`-by-`N` matrix
# (note that `M = 1` is possible as well)
else:
# First step: compute the original dimension `N` of the symmetric matrix by solving
# the quadratic equation `N*(N-1)/2 = K` for `N`, where `K` is the length of the given
# input vector(s): `N**2 - N - 2*K = 0` has the positive root `N = 1/2 + sqrt(1/4 + 2*K)`.
# To guard against roundoff errors, use `np.round`, but double-check the result
K = stw[0]
M = stw[1]
N = int(np.round(0.5 + np.sqrt(0.25 + 2*K)))
if N*(N - 1)/2 != K:
raise ValueError("Provided values cannot be arranged in a triangular layout!")
msk = np.triu(np.ones((N,N),dtype=bool),1)
nw = np.zeros((N,N))
nws = np.zeros((N,N,M))
for m in xrange(M):
nw[:] = 0
nw[msk] = ntw[:,m]
nw += nw.T
nws[:,:,m] = nw.copy()
return nws.squeeze() # If `M = 1` don't return a `N`-by-`N`-by-1 array
# The input is a rank 3 tensor of dimension `(N,N,M)` holding `M` symmetric `N`-by-`N` matrices
else:
# No ambiguity here...
if stw[0] != stw[1]:
raise ValueError("First and second dimension of input tensor must be identical!")
N = stw[0]
M = stw[2]
msk = np.triu(np.ones((N,N),dtype=bool),1)
nws = np.zeros((msk.sum(),M))
for m in xrange(M):
nw = ntw[:,:,m]
if issym(nw):
nws[:,m] = nw[msk]
else:
raise ValueError("Input matrix no. "+str(m)+" is not symmetric!")
return nws
##########################################################################################
def arrcheck(arr,kind,varname,bounds=None):
"""
Local helper function performing sanity checks on arrays (1d/2d/3d)
"""
if not isinstance(arr,np.ndarray):
raise TypeError('Input `'+varname+'` must be a NumPy array, not '+type(arr).__name__+'!')
sha = arr.shape
if kind == 'tensor':
if len(sha) != 3:
raise ValueError('Input `'+varname+'` must be an `N`-by-`N`-by-`k` NumPy array!')
if (min(sha[0],sha[1])==1) or (sha[0]!=sha[1]):
raise ValueError('Input `'+varname+'` must be an `N`-by-`N`-by-`k` NumPy array!')
dim_msg = '`N`-by-`N`-by-`k`'
elif kind == 'matrix':
if len(sha) != 2:
raise ValueError('Input `'+varname+'` must be an `N`-by-`N` NumPy array!')
if (min(sha)==1) or (sha[0]!=sha[1]):
raise ValueError('Input `'+varname+'` must be an `N`-by-`N` NumPy array!')
dim_msg = '`N`-by-`N`'
elif kind == 'vector':
sha = arr.squeeze().shape
if len(sha) != 1:
raise ValueError('Input `'+varname+'` must be a NumPy 1darray!')
if sha[0] <= 1:
raise ValueError('Input `'+varname+'` must be a NumPy 1darray of length `N`!')
dim_msg = '1d'
else:
raise ValueError("Unknown array kind `"+str(kind)+"` - error checking could not be performed!")
if not np.issubdtype(arr.dtype, np.number) or not np.isreal(arr).all():
raise ValueError('Input `'+varname+'` must be a real-valued '+dim_msg+' NumPy array!')
if not np.isfinite(arr).all():
raise ValueError('Input `'+varname+'` must be a real-valued NumPy array without Infs or NaNs!')
if bounds is not None:
if arr.min() < bounds[0] or arr.max() > bounds[1]:
raise ValueError("Values of input array `"+varname+"` must be between "+str(bounds[0])+\
" and "+str(bounds[1])+"!")
##########################################################################################
def scalarcheck(val,varname,kind=None,bounds=None):
"""
Local helper function performing sanity checks on scalars
"""
if not np.isscalar(val) or not plt.is_numlike(val):
raise TypeError("Input `"+varname+"` must be a scalar!")
if not np.isfinite(val) or not np.isreal(val):
raise ValueError("Input `"+varname+"` must be real and finite!")
if kind == 'int':
if (round(val) != val):
raise ValueError("Input `"+varname+"` must be an integer!")
if bounds is not None:
if val < bounds[0] or val > bounds[1]:
raise ValueError("Input scalar `"+varname+"` must be between "+str(bounds[0])+" and "+str(bounds[1])+"!")