Source code for prospect.viewer.cds

# Licensed under a 3-clause BSD style license - see LICENSE.rst
# -*- coding: utf-8 -*-
"""
===================
prospect.viewer.cds
===================

Class containing all bokeh's ColumnDataSource objects needed in viewer.py

"""

import numpy as np
from pkg_resources import resource_filename
from astropy.io import fits
from astropy.table import Table

import bokeh.plotting as bk
from bokeh.models import ColumnDataSource

_specutils_imported = True
try:
    from specutils import Spectrum1D, SpectrumList
except ImportError:
    _specutils_imported = False

_desispec_imported = True
try:
    from desispec.interpolation import resample_flux
except ImportError:
    _desispec_imported = False

from ..coaddcam import coaddcam_prospect
from ..utilities import supported_desitarget_masks, vi_file_fields, load_redrock_templates


def _airtovac(w):
    """Convert air wavelengths to vacuum wavelengths.

    Wavelengths below 2000 Å are returned unchanged.

    Parameters
    ----------
    w : :class:`float`
        Wavelength [Å] of the line in air.

    Returns
    -------
    :class:`float`
        Wavelength [Å] of the line in vacuum.
    """
    if w < 2000.0:
        return w

    vacuum = w
    # Two passes: the correction factor is evaluated at the current vacuum
    # estimate, then re-evaluated once with the refined estimate.
    for _ in range(2):
        wavenumber = 1.0e4 / vacuum
        sigma_sq = wavenumber * wavenumber
        factor = 1.0 + 5.792105e-2 / (238.0185 - sigma_sq)
        factor = factor + 1.67917e-3 / (57.362 - sigma_sq)
        vacuum = w * factor
    return vacuum
class ViewerCDS(object):
    """
    Encapsulates Bokeh ColumnDataSource objects to be passed to js callback functions.

    All attributes start as None and are filled by the ``load_*``, ``init_*``
    and ``compute_*`` methods. Three attributes (``dict_fit_templates``,
    ``dict_std_templates``, ``dict_rrdetails``) are plain dicts, not CDS objects.
    """

    # Galactic extinction coefficients:
    # - Wise bands from https://github.com/dstndstn/tractor/blob/master/tractor/sfd.py
    # - Other bands from desiutil.dust (updated coefficients Apr 2021,
    #   matching https://desi.lbl.gov/trac/wiki/ImagingStandardBandpass)
    _R_EXTINCTION = {'W1': 0.184, 'W2': 0.113, 'W3': 0.0241, 'W4': 0.00910,
                     'G_N': 3.258, 'R_N': 2.176, 'Z_N': 1.199,
                     'G_S': 3.212, 'R_S': 2.164, 'Z_S': 1.211}

    def __init__(self):
        self.cds_spectra = None
        self.cds_median_spectra = None
        self.cds_coaddcam_spec = None
        self.cds_model = None
        self.cds_model_2ndfit = None
        self.cds_othermodel = None
        self.cds_metadata = None
        self.cds_spectral_lines = None
        self.dict_fit_templates = None  # Special case: not a CDS
        self.dict_std_templates = None  # Special case: not a CDS
        self.dict_rrdetails = None  # Special case: not a CDS

    @staticmethod
    def _unpack_spectra(spectra):
        """Return ``(is_desispec, s, bands)`` for the supported spectra containers.

        For specutils objects ``s`` is indexable by band position; anything that is
        neither a SpectrumList nor a Spectrum1D is assumed to be a desispec Spectra
        object, in which case ``s`` is the object itself (indexed by band name).
        """
        if _specutils_imported and isinstance(spectra, SpectrumList):
            return False, spectra, spectra.bands
        if _specutils_imported and isinstance(spectra, Spectrum1D):
            return False, [spectra], ['coadd']
        # Assume desispec Spectra obj
        return True, spectra, spectra.bands

    @staticmethod
    def _ivar_to_noise(ivar):
        """Return 1/sqrt(ivar) where ivar > 0, and 0 elsewhere."""
        noise = np.zeros(len(ivar))
        w, = np.where(ivar > 0)
        noise[w] = 1 / np.sqrt(ivar[w])
        return noise

    @staticmethod
    def _is_placeholder(column):
        """True if a fibermap column is entirely 0 or entirely -1."""
        return np.all(column == 0) or np.all(column == -1)

    @staticmethod
    def _mw_transmission(fibermap, bandname):
        """Return the Milky Way transmission to apply to ``FLUX_<bandname>``.

        Uses, in order of preference: the ``MW_TRANSMISSION_<bandname>`` column;
        ``EBV`` with the WISE coefficients; ``EBV`` and ``PHOTSYS`` with the
        g/r/z coefficients. Falls back to an array of ones.
        """
        transmission = np.ones(len(fibermap['FLUX_' + bandname]))
        columns = fibermap.keys()
        band = bandname.upper()
        if ('MW_TRANSMISSION_' + bandname) in columns:
            transmission = fibermap['MW_TRANSMISSION_' + bandname]
        elif ('EBV' in columns) and (band in ['W1', 'W2', 'W3', 'W4']):
            # Coefficients are extinctions in magnitudes: the /2.5 is required,
            # exactly as in the g/r/z branch below.
            a_band = ViewerCDS._R_EXTINCTION[band] * fibermap['EBV']
            transmission = 10**(-a_band / 2.5)
        elif all(x in columns for x in ['EBV', 'PHOTSYS']) and (band in ['G', 'R', 'Z']):
            for photsys in ['N', 'S']:
                wphot, = np.where(fibermap['PHOTSYS'] == photsys)
                a_band = ViewerCDS._R_EXTINCTION[band + "_" + photsys] * fibermap['EBV'][wphot]
                transmission[wphot] = 10**(-a_band / 2.5)
        return transmission

    def load_spectra(self, spectra, with_noise=True):
        """ Creates column data source for observed spectra

        One ColumnDataSource per band is appended to ``self.cds_spectra``, holding
        the wavelength array, ``origflux<i>`` for each spectrum and, if
        ``with_noise``, ``orignoise<i>``. ``plotflux``/``plotnoise`` start as the
        first spectrum's arrays.
        """
        self.cds_spectra = list()
        is_desispec, s, bands = self._unpack_spectra(spectra)

        for j, band in enumerate(bands):
            input_wave = s.wave[band] if is_desispec else s[j].spectral_axis.value
            input_nspec = spectra.num_spectra() if is_desispec else s[j].flux.shape[0]
            cdsdata = dict(
                origwave=input_wave.copy(),
                plotwave=input_wave.copy(),
            )
            for i in range(input_nspec):
                input_flux = spectra.flux[band][i] if is_desispec else s[j].flux.value[i, :]
                cdsdata['origflux' + str(i)] = input_flux.copy()
                if with_noise:
                    # NOTE(review): for specutils input, uncertainty.array is
                    # treated as an inverse variance - confirm against callers.
                    input_ivar = spectra.ivar[band][i] if is_desispec else s[j].uncertainty.array[i, :]
                    cdsdata['orignoise' + str(i)] = self._ivar_to_noise(input_ivar)
            cdsdata['plotflux'] = cdsdata['origflux0']
            if with_noise:
                cdsdata['plotnoise'] = cdsdata['orignoise0']
            self.cds_spectra.append(ColumnDataSource(cdsdata, name=band))

    def compute_median_spectra(self, spectra):
        """ Stores the median value for each spectrum into CDS.

        Simple concatenation of all values from different bands. NaN values are
        ignored; if a spectrum has no finite-or-infinite (non-NaN) value, 1 is stored.
        """
        cdsdata = dict(median=[])
        is_desispec, s, bands = self._unpack_spectra(spectra)
        nspec = spectra.num_spectra() if is_desispec else s[0].flux.shape[0]

        for i in range(nspec):
            if is_desispec:
                flux_array = np.concatenate(tuple([s.flux[band][i] for band in bands]))
            else:
                flux_array = np.concatenate(tuple([s[j].flux[i, :].value for j, band in enumerate(bands)]))
            w, = np.where(~np.isnan(flux_array))
            if len(w) == 0:
                cdsdata['median'].append(1)
            else:
                cdsdata['median'].append(np.median(flux_array[w]))

        self.cds_median_spectra = ColumnDataSource(cdsdata)

    def init_coaddcam_spec(self, spectra, with_noise=True):
        """ Creates column data source for camera-coadded observed spectra

        Do NOT store all coadded spectra in CDS obj, to reduce size of html files
        Except for the first spectrum, coaddition is done later in javascript
        """
        coadd_wave, coadd_flux, coadd_ivar = coaddcam_prospect(spectra)
        cds_coaddcam_data = dict(
            origwave=coadd_wave.copy(),
            plotwave=coadd_wave.copy(),
            plotflux=coadd_flux[0, :].copy(),
            plotnoise=np.ones(len(coadd_wave)),
        )
        if with_noise:
            # Pixels with non-positive ivar keep the default noise of 1.
            w, = np.where(coadd_ivar[0, :] > 0)
            cds_coaddcam_data['plotnoise'][w] = 1 / np.sqrt(coadd_ivar[0, :][w])
        self.cds_coaddcam_spec = ColumnDataSource(cds_coaddcam_data)

    def init_model(self, model, second_fit=False):
        """ Creates a CDS for model spectrum

        ``model`` is a ``(wave, flux)`` pair where ``flux[i]`` is the model of
        spectrum i. The result goes to ``self.cds_model_2ndfit`` if ``second_fit``
        is set, to ``self.cds_model`` otherwise.
        """
        mwave, mflux = model
        cdsdata = dict(
            origwave=mwave.copy(),
            plotwave=mwave.copy(),
            plotflux=np.zeros(len(mwave)),
        )
        for i in range(len(mflux)):
            cdsdata['origflux' + str(i)] = mflux[i]
        cdsdata['plotflux'] = cdsdata['origflux0']

        if second_fit:
            self.cds_model_2ndfit = ColumnDataSource(cdsdata)
        else:
            self.cds_model = ColumnDataSource(cdsdata)

    def init_othermodel(self, zcatalog):
        """ Initialize CDS for the 'other model' curve, from the best fit

        Requires ``self.cds_model`` to be set (see ``init_model``).
        """
        model_data = self.cds_model.data
        self.cds_othermodel = ColumnDataSource({
            'plotwave': model_data['plotwave'],
            'origwave': model_data['origwave'],
            'origflux': model_data['origflux0'],
            'plotflux': model_data['origflux0'],
            # Track z reference in model
            'zref': zcatalog['Z'][0] + np.zeros(len(model_data['origflux0'])),
        })

    def load_fit_templates(self, template_dir=None, nbpts_templates=4000):
        """ Create dict for spectral templates used in Redrock fits.

        These are used to recompute Redrock's Nth best-fit spectra on-the-fly in javascript.
        Templates are resampled in order to limit the size of html pages (and the browser's CPU usage).
        This resampling is dictated by parameter nbpts_templates.
        """
        assert _desispec_imported, "desispec is required (resample_flux)"  # for resample_flux
        rr_templts = load_redrock_templates(template_dir=template_dir)
        self.dict_fit_templates = dict()
        for key, templt in rr_templts.items():
            fulltype_key = "_".join(key)  # merge redrock's (TYPE, SUBTYPE)
            wave_array = np.linspace(templt.wave[0], templt.wave[-1], num=nbpts_templates)
            n_components = templt.flux.shape[0]
            flux_array = np.zeros((n_components, len(wave_array)))
            for i in range(n_components):
                flux_array[i, :] = resample_flux(wave_array, templt.wave, templt.flux[i, :])
            self.dict_fit_templates["wave_" + fulltype_key] = wave_array
            self.dict_fit_templates["flux_" + fulltype_key] = flux_array

    def load_std_templates(self, std_template_file=None):
        """ Load a dict of "standard" templates.

        The std template file is `data/std_templates.fits`.
        It was created from `../scripts/prospect_std_templates.py`.

        Raises ValueError if a table column is not named ``wave_*``/``flux_*``,
        or if a wavelength array is not regularly, linearly binned.
        """
        self.dict_std_templates = dict()
        if std_template_file is None:
            std_template_file = resource_filename('prospect', "data/std_templates.fits")
        # Context manager: the file is closed even if reading its structure fails.
        with fits.open(std_template_file) as hdul:
            nhdu = len(hdul)

        for i in range(1, nhdu):
            t = Table.read(std_template_file, hdu=i)
            for key in t.keys():
                #- check table column name:
                if key[:5] not in ['wave_', 'flux_']:
                    raise ValueError('STD template file: wrong column name ('+key+')')
                #- check wavelength array is regularly, linearly binned (with absolute tolerance 0.01 AA):
                if key[:5] == 'wave_':
                    waves = np.array(t[key])
                    delta_waves = waves[1:] - waves[:-1]
                    if not np.allclose(delta_waves, delta_waves[0], atol=0.01, rtol=1.e-10):
                        raise ValueError('STD template file: found irregular wavelength binning ('+key+')')
                self.dict_std_templates[key] = np.array(t[key])

        #- initialize cds_othermodel, if this was not done yet:
        if self.cds_othermodel is None:
            key_zero = list(self.dict_std_templates.keys())[0][5:]
            wave_zero = self.dict_std_templates['wave_' + key_zero]
            flux_zero = self.dict_std_templates['flux_' + key_zero]
            self.cds_othermodel = ColumnDataSource({
                'plotwave': wave_zero,
                'origwave': wave_zero,
                'origflux': flux_zero,
                'plotflux': flux_zero,
                'zref': np.zeros(len(flux_zero)),  # std templates have z=0
            })

    def load_rrdetails(self, redrock_cat):
        """ Create dict for detailed redrock outputs.

        Used to recompute redrock's Nth best fit spectra on-the-fly in javascript,
        and display them in a table. Every column of ``redrock_cat`` is stored as a
        numpy array; ``'Nfit'`` is the size of the second axis of column ``'Z'``.
        """
        self.dict_rrdetails = dict()
        for key in redrock_cat.keys():
            self.dict_rrdetails[key] = np.asarray(redrock_cat[key])
        self.dict_rrdetails['Nfit'] = redrock_cat['Z'].shape[1]

    def _add_desi_fibermap_metadata(self, fibermap):
        """Add TARGETID and the optional fibermap columns to ``self.cds_metadata``."""
        # Optional metadata:
        fibermap_keys = ['HPXPIXEL', 'MORPHTYPE', 'CAMERA',
                         'COADD_NUMEXP', 'COADD_EXPTIME',
                         'COADD_NUMNIGHT', 'COADD_NUMTILE']
        # Optional metadata, will check matching FIRST/LAST/NUM keys in fibermap:
        special_fm_keys = ['FIBER', 'NIGHT', 'EXPID', 'TILEID']
        # Arbitrary choice: COADD_NUM* is dropped when the matching NUM_* was loaded
        superseded_by = {'COADD_NUMEXP': 'NUM_EXPID',
                         'COADD_NUMNIGHT': 'NUM_NIGHT',
                         'COADD_NUMTILE': 'NUM_TILEID'}

        #- Special case for targetids: No int64 in js !!
        self.cds_metadata.add([str(x) for x in fibermap['TARGETID']], name='TARGETID')

        #- "Special" keys: check for FIRST/LAST/NUM
        for fm_key in special_fm_keys:
            use_first_last_num = False
            if all([(x + fm_key in fibermap.keys()) for x in ['FIRST_', 'LAST_', 'NUM_']]):
                if np.any(fibermap['NUM_' + fm_key] > 1):  # if NUM==1, use fm_key only
                    use_first_last_num = True
                    for prefix in ['FIRST_', 'LAST_', 'NUM_']:
                        self.cds_metadata.add(fibermap[prefix + fm_key], name=prefix + fm_key)
            if (not use_first_last_num) and fm_key in fibermap.keys():
                # Do not load placeholder metadata:
                if not self._is_placeholder(fibermap[fm_key]):
                    self.cds_metadata.add(fibermap[fm_key], name=fm_key)

        #- "Normal" keys
        for fm_key in fibermap_keys:
            if fm_key in superseded_by and superseded_by[fm_key] in self.cds_metadata.data.keys():
                continue
            if fm_key in fibermap.keys():
                if not self._is_placeholder(fibermap[fm_key]):
                    self.cds_metadata.add(fibermap[fm_key], name=fm_key)

    def load_metadata(self, spectra, mask_type=None, zcatalog=None, survey='DESI'):
        """ Creates column data source for target-related metadata,
        from fibermap, zcatalog and VI files

        ``survey`` must be ``'DESI'`` or ``'SDSS'`` (ValueError otherwise).
        If ``mask_type`` is set, a 'Targeting masks' column is added; for DESI a
        ValueError is raised when ``mask_type`` is not a fibermap column.
        If ``zcatalog`` is set, the columns listed in ``self.zcat_keys`` are copied.
        Also sets ``self.zcat_keys`` and ``self.phot_bands``.
        """
        if survey == 'DESI':
            nspec = spectra.num_spectra()
            # Mandatory keys if zcatalog is set:
            self.zcat_keys = ['Z', 'SPECTYPE', 'SUBTYPE', 'ZERR', 'ZWARN', 'DELTACHI2']
            # Mandatory metadata:
            self.phot_bands = ['G', 'R', 'Z', 'W1', 'W2']
            supported_masks = supported_desitarget_masks
        elif survey == 'SDSS':
            nspec = spectra.flux.shape[0]
            # Mandatory keys if zcatalog is set:
            self.zcat_keys = ['Z', 'CLASS', 'SUBCLASS', 'Z_ERR', 'ZWARNING', 'RCHI2DIFF']
            # Mandatory metadata:
            self.phot_bands = ['u', 'g', 'r', 'i', 'z']
            supported_masks = ['PRIMTARGET', 'SECTARGET',
                               'BOSS_TARGET1', 'BOSS_TARGET2',
                               'ANCILLARY_TARGET1', 'ANCILLARY_TARGET2',
                               'EBOSS_TARGET0', 'EBOSS_TARGET1', 'EBOSS_TARGET2']
        else:
            raise ValueError('Wrong survey')

        self.cds_metadata = ColumnDataSource()

        #- Generic metadata
        if survey == 'DESI':
            self._add_desi_fibermap_metadata(spectra.fibermap)
        elif survey == 'SDSS':
            #- Set 'TARGETID' name to OBJID for convenience
            self.cds_metadata.add([str(x.tolist()) for x in spectra.meta['plugmap']['OBJID']], name='TARGETID')

        #- Photometry
        for i, bandname in enumerate(self.phot_bands):
            if survey == 'SDSS':
                mag = spectra.meta['plugmap']['MAG'][:, i]
            else:
                mag = np.zeros(nspec)
                flux = spectra.fibermap['FLUX_' + bandname]
                extinction = self._mw_transmission(spectra.fibermap, bandname)
                # Magnitude stays 0 where flux or transmission is not positive.
                w, = np.where((flux > 0) & (extinction > 0))
                mag[w] = -2.5 * np.log10(flux[w] / extinction[w]) + 22.5
            self.cds_metadata.add(mag, name='mag_' + bandname)

        #- Targeting masks
        if mask_type is not None:
            if survey == 'DESI':
                if mask_type not in spectra.fibermap.keys():
                    mask_candidates = [x for x in spectra.fibermap.keys() if '_TARGET' in x]
                    raise ValueError(mask_type+" is not in spectra.fibermap.\n Hints of available masks: "+(' '.join(mask_candidates)))
                mask_used = supported_masks[mask_type]
                target_bits = spectra.fibermap[mask_type]
                target_info = [' '.join(mask_used.names(x)) for x in target_bits]
            elif survey == 'SDSS':
                assert mask_type in supported_masks
                target_info = [mask_type + ' (DUMMY)' for x in spectra.meta['plugmap']]  # placeholder
            self.cds_metadata.add(target_info, name='Targeting masks')

        #- Software versions
        #- TODO : get template version (from zcatalog...)
        if survey == 'SDSS':
            spec_version = 'SDSS'
        else:
            spec_version = '0'
            for xx, yy in spectra.meta.items():
                if yy == "desispec":
                    spec_version = spectra.meta[xx.replace('NAM', 'VER')]
        self.cds_metadata.add([spec_version for i in range(nspec)], name='spec_version')
        redrock_version = ["-1" for i in range(nspec)]
        if zcatalog is not None:
            if 'RRVER' in zcatalog.keys():
                redrock_version = zcatalog['RRVER'].data
        self.cds_metadata.add(redrock_version, name='redrock_version')
        self.cds_metadata.add(np.zeros(nspec) - 1, name='template_version')

        #- Redshift fit
        if zcatalog is not None:
            for zcat_key in self.zcat_keys:
                if 'TYPE' in zcat_key or 'CLASS' in zcat_key:
                    # Cast to a unicode dtype as wide as the column's item size
                    data = zcatalog[zcat_key].astype('U{0:d}'.format(zcatalog[zcat_key].dtype.itemsize))
                else:
                    data = zcatalog[zcat_key]
                self.cds_metadata.add(data, name=zcat_key)

        #- VI informations
        default_vi_info = [(x[1], x[3]) for x in vi_file_fields if x[0][0:3] == "VI_"]
        for vi_key, vi_value in default_vi_info:
            self.cds_metadata.add([vi_value for i in range(nspec)], name=vi_key)

    def load_spectral_lines(self, z=0):
        """ Creates column data source for spectral lines

        Reads ``data/emission_lines.txt`` and ``data/absorption_lines.txt`` from
        the ``prospect`` package. Wavelengths flagged as not in vacuum are
        converted with ``_airtovac``; ``plotwave`` is the rest wavelength
        multiplied by ``(1+z)``.
        """
        line_data = dict(
            restwave=[],
            plotwave=[],
            name=[],
            longname=[],
            plotname=[],
            emission=[],
            major=[],
        )
        line_dtype = [("name", "|U20"),
                      ("longname", "|U20"),
                      ("wavelength", float),
                      ("vacuum", bool),
                      ("major", bool)]
        for line_category in ('emission', 'absorption'):
            # encoding=utf-8 is needed to read greek letters
            line_array = np.genfromtxt(
                resource_filename('prospect', "data/{0}_lines.txt".format(line_category)),
                delimiter=",", dtype=line_dtype, encoding='utf-8')
            # genfromtxt returns a 0-d array for a one-row file
            line_array = np.atleast_1d(line_array)
            # Work on a copy so the array read from file is left untouched
            vacuum_wavelengths = line_array['wavelength'].copy()
            w, = np.where(~line_array['vacuum'])
            vacuum_wavelengths[w] = np.array([_airtovac(wave) for wave in line_array['wavelength'][w]])
            line_data['restwave'].extend(vacuum_wavelengths)
            line_data['plotwave'].extend(vacuum_wavelengths * (1 + z))
            line_data['name'].extend(line_array['name'])
            line_data['longname'].extend(line_array['longname'])
            line_data['plotname'].extend(line_array['name'])
            emission_flag = (line_category == 'emission')
            line_data['emission'].extend([emission_flag for row in line_array])
            line_data['major'].extend(line_array['major'])

        self.cds_spectral_lines = ColumnDataSource(line_data)