import collections
import astropy.units as u
import numpy as np
from astromodels.core.spectral_component import SpectralComponent
from astromodels.core.tree import Node
from astromodels.core.units import get_units
from astromodels.functions import Constant
from astromodels.sources.source import Source, SourceType
from astromodels.utils.pretty_list import dict_to_list
from astromodels.utils.logging import setup_logger
# Module-level logger (created with astromodels' setup_logger) used to report
# invalid arguments before raising.
log = setup_logger(__name__)
class ExtendedSource(Source, Node):
    """
    A spatially extended source.

    The source combines a spatial shape with one or more spectral components.
    Two kinds of spatial shape are supported:

    * 2D shapes (lon, lat): an energy-independent brightness map; the spectrum
      comes from the spectral component(s).
    * 3D shapes (lon, lat, energy): either a template that provides the whole
      spectrum by itself (renormalized by a ``Constant`` component named
      "main"), or an energy-dependent morphology multiplied by the spectral
      component(s).

    The source is also a node of the model tree. Its children are the spatial
    shape and a "spectrum" node holding the spectral components.
    """

    def __init__(
        self,
        source_name,
        spatial_shape,
        spectral_shape=None,
        components=None,
        polarization=None,
    ):
        """
        :param source_name: name of the source, also used as the tree-node name
        :param spatial_shape: spatial function; its ``n_dim`` must be 2 or 3
        :param spectral_shape: a single spectral function, wrapped in a
            component named "main". Mutually exclusive with ``components``.
        :param components: list of spectral components. Mutually exclusive
            with ``spectral_shape``.
        :param polarization: passed to the "main" component; only used when
            ``spectral_shape`` is given
        :raises AssertionError: if not exactly one of ``spectral_shape`` /
            ``components`` is given (omitting both is only allowed for a 3D
            template)
        :raises RuntimeError: if the spatial shape is neither 2D nor 3D
        """
        current_u = get_units()

        if spatial_shape.n_dim == 2:
            components = self._make_spectral_components(
                spectral_shape, components, polarization, current_u
            )

            # Brightness: per unit solid angle
            spatial_shape.set_units(
                current_u.angle, current_u.angle, current_u.angle ** (-2)
            )

        elif spatial_shape.n_dim == 3:
            if spectral_shape is None and components is None:
                # No spectral information: the template provides the spectrum
                # by itself, so only add a renormalization (a bias)
                spectral_shape = Constant()
                components = [SpectralComponent("main", spectral_shape)]

                # The template returns differential flux per unit solid angle
                diff_flux_units = (
                    current_u.energy
                    * current_u.area
                    * current_u.time
                    * current_u.angle**2
                ) ** (-1)

                spatial_shape.set_units(
                    current_u.angle,
                    current_u.angle,
                    current_u.energy,
                    diff_flux_units,
                )

            else:
                # Energy-dependent morphology; the spectral component(s) give
                # the spectrum
                components = self._make_spectral_components(
                    spectral_shape, components, polarization, current_u
                )

                spatial_shape.set_units(
                    current_u.angle,
                    current_u.angle,
                    current_u.energy,
                    current_u.angle ** (-2),
                )

        else:
            log.error("The spatial shape must have either 2 or 3 dimensions.")
            raise RuntimeError()

        # Here we have a list of components
        Source.__init__(self, components, SourceType.EXTENDED_SOURCE)

        # A source is also a Node in the tree
        Node.__init__(self, source_name)

        # The spatial shape is a direct child of this node
        self._spatial_shape = spatial_shape
        self._add_child(self._spatial_shape)

        # The spectral components hang below a node called 'spectrum'
        spectrum_node = Node("spectrum")
        spectrum_node._add_children(list(self._components.values()))
        self._add_child(spectrum_node)

    @staticmethod
    def _make_spectral_components(spectral_shape, components, polarization, current_u):
        """
        Validate the spectral inputs and return the components with units set.

        Exactly one of ``spectral_shape`` / ``components`` must be given; a
        single shape is wrapped in a component named "main". Every component
        shape gets energy as x and differential flux (1 / (energy area time))
        as y.

        :raises AssertionError: if both or neither of the inputs are given
        """
        if (spectral_shape is None) == (components is None):
            # Explicit raise instead of `assert`, so the check survives -O
            msg = (
                "You have to provide either a single "
                "component, or a list of components "
                "(but not both)."
            )
            log.error(msg)
            raise AssertionError(msg)

        if spectral_shape is not None:
            components = [SpectralComponent("main", spectral_shape, polarization)]

        diff_flux_units = (current_u.energy * current_u.area * current_u.time) ** (
            -1
        )

        for component in components:
            component.shape.set_units(current_u.energy, diff_flux_units)

        return components

    @property
    def spatial_shape(self):
        """
        A generic name for the spatial shape.

        :return: the spatial shape instance
        """
        return self._spatial_shape

    def get_spatially_integrated_flux(self, energies):
        """
        Returns the total (spatially integrated) flux of the source.

        :param energies: energies (array or float); astropy Quantities take the
            unit-preserving path
        :return: differential flux at the given energies, summed over all
            spectral components
        """
        if not isinstance(energies, np.ndarray):
            energies = np.array(energies, ndmin=1)

        # The spatial integral does not depend on the component: compute once
        spatial_integral = self.spatial_shape.get_total_spatial_integral(energies)

        results = [
            spatial_integral * component.shape(energies)
            for component in self.components.values()
        ]

        if isinstance(energies, u.Quantity):
            # np.sum would not preserve the units, so use the (slower) builtin
            differential_flux = sum(results)
        else:
            # Fast path: values are assumed to be in the units of get_units()
            differential_flux = np.sum(results, 0)

        return differential_flux

    def __call__(self, lon, lat, energies):
        """
        Returns brightness of source at the given position and energy.

        :param lon: longitude (array or float)
        :param lat: latitude (array or float)
        :param energies: energies (array or float)
        :return: differential flux at given position and energy, squeezed. For
            2D shapes the unsqueezed result is (n_points, n_energies), assuming
            the spatial shape returns one brightness value per point.
        :raises AssertionError: if lon, lat and energies differ in type
        """
        if not (type(lat) is type(lon) and type(lon) is type(energies)):
            # Explicit raise instead of `assert`, so the check survives -O
            log.error("Type mismatch in input of call")
            raise AssertionError("Type mismatch in input of call")

        if not isinstance(lat, np.ndarray):
            lat = np.array(lat, ndmin=1)
            lon = np.array(lon, ndmin=1)
            energies = np.array(energies, ndmin=1)

        # Differential flux summed over the spectral components
        results = [component.shape(energies) for component in self.components.values()]

        if isinstance(energies, u.Quantity):
            # np.sum would not preserve the units, so use the (slower) builtin
            differential_flux = sum(results)
        else:
            # Fast path: values are assumed to be in the units of get_units()
            differential_flux = np.sum(results, 0)

        if self._spatial_shape.n_dim == 2:
            brightness = self._spatial_shape(lon, lat)

            # The spectrum is the same everywhere: scale it by the brightness at
            # each point. The repeat/reshape trick is much faster than a loop.
            n_points = lat.shape[0]
            n_energies = differential_flux.shape[0]
            cube = (
                np.repeat(differential_flux, n_points).reshape(n_energies, n_points).T
            )
            result = (cube.T * brightness).T
        else:
            result = self._spatial_shape(lon, lat, energies) * differential_flux

        # Do not clip the output, otherwise extended sources with negative
        # fluxes could not be used
        return np.squeeze(result)

    def _iter_all_parameters(self):
        """Yield every parameter: spectral components first, then spatial shape."""
        for component in self._components.values():
            yield from component.shape.parameters.values()
        yield from self.spatial_shape.parameters.values()

    @property
    def has_free_parameters(self):
        """
        Returns True if any parameter of this source is free, False otherwise.

        :return: bool
        """
        return any(par.free for par in self._iter_all_parameters())

    @property
    def free_parameters(self):
        """
        Returns a dictionary of free parameters for this source.

        The parameter path is used as the key because it is guaranteed to be
        unique, unlike the parameter name.

        :return: OrderedDict mapping path -> parameter
        """
        return collections.OrderedDict(
            (par.path, par) for par in self._iter_all_parameters() if par.free
        )

    @property
    def parameters(self):
        """
        Returns a dictionary of all parameters for this source.

        The parameter path is used as the key because it is guaranteed to be
        unique, unlike the parameter name.

        :return: OrderedDict mapping path -> parameter
        """
        return collections.OrderedDict(
            (par.path, par) for par in self._iter_all_parameters()
        )

    def _repr__base(self, rich_output=False):
        """
        Representation of the object.

        :param rich_output: if True, generates HTML, otherwise text
        :return: the representation
        """
        # Build a nested dictionary, then turn it into a list for display
        repr_dict = collections.OrderedDict()

        key = "%s (extended source)" % self.name

        repr_dict[key] = collections.OrderedDict()
        repr_dict[key]["shape"] = self._spatial_shape.to_dict(minimal=True)
        repr_dict[key]["spectrum"] = collections.OrderedDict()

        for component_name, component in self.components.items():
            repr_dict[key]["spectrum"][component_name] = component.to_dict(
                minimal=True
            )

        return dict_to_list(repr_dict, rich_output)

    def get_boundaries(self):
        """
        Returns the boundaries for this extended source.

        :return: a tuple of tuples ((min. lon, max. lon), (min lat, max lat))
        """
        return self._spatial_shape.get_boundaries()