Source code for astromodels.core.model_parser

from builtins import object, str

__author__ = "giacomov"

import re
import warnings
from typing import Any, Dict, List, Optional, Union

from astromodels.core import (
    model,
    parameter,
    polarization,
    sky_direction,
    spectral_component,
)
from astromodels.core.my_yaml import my_yaml
from astromodels.functions import function
from astromodels.sources import extended_source, particle_source, point_source
from astromodels.sources.source import SourceType
from astromodels.utils.logging import setup_logger

log = setup_logger(__name__)


class ModelIOError(IOError):
    """Raised when the model file cannot be opened or read."""
class ModelYAMLError(my_yaml.YAMLError):
    """Raised when the model file cannot be parsed as YAML."""
class ModelSyntaxError(RuntimeError):
    """Raised when a model definition is structurally invalid."""
def load_model(filename):
    """
    Build a model from the definition stored in a file.

    :param filename: the name of the file containing the model definition
    :return: the Model instance produced by ModelParser.get_model()
    """

    return ModelParser(filename).get_model()
def clone_model(model_instance):
    """
    Return a copy of the given model, rebuilt from its dictionary form.

    The model is serialized with ``to_dict_with_types()`` and the resulting
    dictionary is parsed again, which mirrors a save/reload cycle without
    going through the disk.

    :param model_instance: model to be cloned
    :return: a new model built from the serialized form of model_instance
    """

    serialized = model_instance.to_dict_with_types()

    return ModelParser(model_dict=serialized).get_model()
def model_unpickler(state):
    """
    Rebuild a model from its dictionary representation.

    :param state: dictionary with the model definition
        (presumably the state stored when a model is pickled; verify against caller)
    :return: the model built by ModelParser from state
    """

    parser = ModelParser(model_dict=state)

    return parser.get_model()
class ModelParser(object):
    """
    Build a Model from a YAML file or from an already-deserialized dictionary.

    The whole definition is parsed at construction time; get_model() then
    assembles a new Model from the parsed pieces.

    :param model_file: name of the YAML file with the model definition (optional)
    :param model_dict: dictionary with the model definition. It is only used
        when model_file is None
    :raises AssertionError: if neither model_file nor model_dict is provided
    :raises ModelIOError: if model_file cannot be read
    :raises ModelYAMLError: if model_file cannot be parsed
    """

    def __init__(self, model_file=None, model_dict=None):

        if model_file is None and model_dict is None:

            message = "You have to provide either a model file or a model dictionary"

            log.error(message)

            raise AssertionError(message)

        if model_file is not None:

            # Read model file and deserialize into a dictionary
            # NOTE(review): confirm that my_yaml.load with FullLoader is
            # acceptable if model files can come from untrusted sources.
            try:

                with open(model_file) as f:

                    self._model_dict = my_yaml.load(f, Loader=my_yaml.FullLoader)

            except IOError as exc:

                message = (
                    "File %s cannot be read. Check path and permissions for current user."
                    % model_file
                )

                log.error(message)

                raise ModelIOError(message) from exc

            except my_yaml.YAMLError as exc:

                message = "Could not parse file %s. Check your syntax." % model_file

                log.error(message)

                raise ModelYAMLError(message) from exc

        else:

            self._model_dict = model_dict

        self._parse()

    def _parse(self):
        """
        Traverse the model dictionary and instance all the needed objects.

        Each top-level key is either an independent variable (its name contains
        "(IndependentVariable)"), an external parameter (its name contains
        "(Parameter)") or, in any other case, a source.
        """

        self._sources = []
        self._independent_variables = []
        self._external_parameters = []
        self._links = []
        self._external_parameter_links = []
        self._extra_setups = []
        self._external_functions = []

        for entry_name, entry_definition in list(self._model_dict.items()):

            # first look for independent variable
            if entry_name.find("(IndependentVariable)") > 0:

                var_name = entry_name.split("(")[0].replace(" ", "")

                this_parser = IndependentVariableParser(var_name, entry_definition)

                res = this_parser.get_variable()

                assert isinstance(res, parameter.IndependentVariable)

                self._independent_variables.append(res)

            elif entry_name.find("(Parameter)") > 0:

                var_name = entry_name.split("(")[0].replace(" ", "")

                this_parser = ParameterParser(var_name, entry_definition)

                res = this_parser.get_variable()

                assert isinstance(res, parameter.Parameter)

                self._external_parameters.append(res)

                self._links.extend(this_parser.links)

            else:

                this_parser = SourceParser(entry_name, entry_definition)

                res = this_parser.get_source()

                assert isinstance(
                    res,
                    (
                        point_source.PointSource,
                        extended_source.ExtendedSource,
                        particle_source.ParticleSource,
                    ),
                )

                self._sources.append(res)

                self._links.extend(this_parser.links)

                self._extra_setups.extend(this_parser.extra_setups)

                self._external_functions.extend(this_parser.external_functions)

    def get_model(self):
        """
        Assemble a new Model from the parsed definition.

        Sources are added first, then independent variables and external
        parameters, and only afterwards links, extra setups and external
        functions, because these last three are resolved by path inside the
        new model.

        :return: the new Model instance
        """

        # Instance the model with all the parsed sources
        new_model = model.Model(*self._sources)

        # Now set up IndependentVariable instances (if any)
        for independent_variable in self._independent_variables:

            new_model.add_independent_variable(independent_variable)

        # Now set up external parameters (if any)
        for par in self._external_parameters:

            new_model.add_external_parameter(par)

        # Now set up the links
        for link in self._links:

            linked_parameter = new_model[link["parameter_path"]]

            linked_parameter.add_auxiliary_variable(
                new_model[link["variable"]], link["law"]
            )

        # the extra_setups (if any)
        for extra_setup in self._extra_setups:

            target = new_model[extra_setup["function_path"]]

            for property_name, value in list(extra_setup["extra_setup"].items()):

                log.debug(f"adding {property_name} with {value}")

                # If the value is a valid path in the new model, use the object
                # found at that path. Otherwise interpret it as a plain value.
                if value in new_model:

                    setattr(target, property_name, new_model[value])

                else:

                    setattr(target, property_name, value)

        # finally the external functions if any
        for external_function in self._external_functions:

            target = new_model[external_function["function_path"]]

            linked_functions = external_function["external_functions"]

            if external_function["is_composite"]:

                # For a composite function the links are stored separately for
                # each of the sub-functions, indexed by position
                for i, primary_func in enumerate(target._functions):

                    this_ef = linked_functions[i]

                    if not this_ef:

                        # no external functions linked to this sub-function
                        continue

                    for fname, linked_function in this_ef.items():

                        primary_func.link_external_function(
                            function=new_model[linked_function],
                            internal_name=fname,
                        )

            else:

                for fname, linked_function in linked_functions.items():

                    target.link_external_function(
                        function=new_model[linked_function], internal_name=fname
                    )

        return new_model
class IndependentVariableParser(object):
    """
    Instance a parameter.IndependentVariable from its name and definition.

    :param name: name of the independent variable
    :param definition: dictionary whose items are passed as keyword arguments
        to parameter.IndependentVariable
    """

    def __init__(self, name, definition):

        variable = parameter.IndependentVariable(name, **definition)

        self._variable = variable

    def get_variable(self):
        """Return the IndependentVariable instanced by the constructor."""

        return self._variable
class ParameterParser(object):
    """
    Instance a parameter.Parameter from its name and definition.

    NOTE: this is used only for parameters outside of functions.

    If the definition contains a 'prior', the prior function is instanced
    first. If the 'value' has the form f(path), the parameter is linked to the
    variable found at that path through the function given in 'law': the link
    is stored in the `links` property and the parameter is created with a
    placeholder value.

    :param name: name of the parameter
    :param definition: dictionary with the definition of the parameter. It is
        not modified
    :raises ModelSyntaxError: if the prior is malformed, or if the parameter is
        linked but has no 'law'
    """

    def __init__(self, name, definition):

        self._links = []

        # Work on a shallow copy: the entries replaced below ('prior', 'law',
        # 'value') must not leak into the dictionary owned by the caller,
        # otherwise parsing the same dictionary a second time loses the link
        definition = dict(definition)

        if "prior" in definition:

            # Need the create a function for the prior first
            try:

                function_name = list(definition["prior"].keys())[0]
                parameters_definition = definition["prior"][function_name]

            except (KeyError, IndexError) as exc:  # pragma: no cover

                message = "The prior for parameter %s is malformed" % name

                log.error(message)

                raise ModelSyntaxError(message) from exc

            # parse the function
            shape_parser = ShapeParser(name)

            prior_instance = shape_parser.parse(
                name, function_name, parameters_definition
            )

            # Substitute the definition with the instance, so that the
            # following constructor will work
            definition["prior"] = prior_instance

        # Check if this is a linked parameter, i.e., if 'value' is something
        # like f(source.spectrum.powerlaw.index)
        matches = re.findall(r"f\((.+)\)", str(definition["value"]))

        if matches:

            # This is an expression which marks a parameter with a link to
            # another parameter (or an IndependentVariable such as time)
            linked_variable = matches[0]

            # Now get the law
            if "law" not in definition:  # pragma: no cover

                # No function name here: the only function name available at
                # this point would be the one of the prior, which may not exist
                message = (
                    "The parameter %s is linked to %s but lacks a 'law' attribute"
                    % (name, linked_variable)
                )

                log.error(message)

                raise ModelSyntaxError(message)

            link_function_name = list(definition["law"].keys())[0]

            # ok, now we parse the linked parameter
            function_parser = ShapeParser(name)

            link_function_instance = function_parser.parse(
                name, link_function_name, definition["law"][link_function_name]
            )

            self._links.append(
                {
                    "parameter_path": name,
                    "law": link_function_instance,
                    "variable": linked_variable,
                }
            )

            # get rid of the 'law' entry, which is not an argument of Parameter
            definition.pop("law", None)

            # this parameter's value will be replaced later.
            # for now we just need to get rid of the f(param) entry
            definition["value"] = 1.0

        self._variable = parameter.Parameter(name, **definition)

    def get_variable(self):
        """Return the Parameter instanced by the constructor."""

        return self._variable

    @property
    def links(self):
        """
        List of the links found while parsing (empty if there are none).

        Each link is a dictionary with keys 'parameter_path', 'law' and
        'variable'.
        """

        return self._links
class SourceParser(object):
    """
    Parse the definition of one source of the model.

    :param source_name: key of the source in the model dictionary. It must
        contain the type of the source between parentheses; the name of the
        source is the part before the first whitespace
    :param source_definition: dictionary with the definition of the source
    :raises ModelSyntaxError: if the type of the source is not recognized or
        if the definition is malformed
    """

    def __init__(self, source_name, source_definition):

        # The values of the enum are used here for consistency with the
        # comparisons made below to select the parser
        valid_types = (
            SourceType.POINT_SOURCE.value,
            SourceType.EXTENDED_SOURCE.value,
            SourceType.PARTICLE_SOURCE.value,
        )

        # Get the type of the source
        try:

            source_type = re.findall(r"\((%s|%s|%s)\)" % valid_types, source_name)[-1]

        except IndexError as exc:  # pragma: no cover

            message = (
                "Don't recognize type for source '%s'. "
                "Valid types are '%s', '%s' or '%s'." % ((source_name,) + valid_types)
            )

            log.error(message)

            raise ModelSyntaxError(message) from exc

        else:

            # Strip the source_type from the name
            source_name = source_name.split()[0]

        self._source_name = source_name

        # This will store the links (if any)
        self._links = []

        # This will store extra_setups (if any), used sometimes. For example,
        # the function which uses naima to make a synchrotron spectrum uses
        # this to save and set up the particle distribution
        self._extra_setups = []

        # this will store any externally linked functions
        self._external_functions = []

        if source_type == SourceType.POINT_SOURCE.value:

            self._parsed_source = self._parse_point_source(source_definition)

        elif source_type == SourceType.EXTENDED_SOURCE.value:

            self._parsed_source = self._parse_extended_source(source_definition)

        elif source_type == SourceType.PARTICLE_SOURCE.value:

            self._parsed_source = self._parse_particle_source(source_definition)

    @property
    def extra_setups(self):
        """List of the extra setups collected while parsing the shapes."""

        return self._extra_setups

    @property
    def external_functions(self) -> List[Dict[str, Any]]:
        """List of the external function links collected while parsing the shapes."""

        return self._external_functions

    @property
    def links(self):
        """List of the parameter links collected while parsing the shapes."""

        return self._links

    def get_source(self):
        """Return the source instanced by the constructor."""

        return self._parsed_source

    def _parse_components(self, source_definition, kind_label):
        """
        Parse all the spectral components of a source.

        :param source_definition: dictionary with the definition of the source
        :param kind_label: description of the kind of source, used as the
            beginning of the error message
        :return: list of spectral components, in the order of the definition
        :raises ModelSyntaxError: if there is no 'spectrum' in the definition
        """

        try:

            spectrum = source_definition["spectrum"]

        except KeyError as exc:  # pragma: no cover

            message = "%s %s is missing the 'spectrum' attribute" % (
                kind_label,
                self._source_name,
            )

            log.error(message)

            raise ModelSyntaxError(message) from exc

        return [
            self._parse_spectral_component(component_name, component_definition)
            for component_name, component_definition in list(spectrum.items())
        ]

    def _parse_particle_source(self, particle_source_definition):
        """Instance a ParticleSource from its definition."""

        components = self._parse_components(
            particle_source_definition, "Particle source"
        )

        return particle_source.ParticleSource(
            self._source_name, components=components
        )

    def _parse_point_source(self, pts_source_definition):
        """Instance a PointSource from its definition (position and spectrum)."""

        # Parse the positional information
        try:

            position_definition = pts_source_definition["position"]

        except KeyError as exc:  # pragma: no cover

            message = (
                "Point source %s is missing the 'position' attribute"
                % self._source_name
            )

            log.error(message)

            raise ModelSyntaxError(message) from exc

        this_sky_direction = self._parse_sky_direction(position_definition)

        # Parse the spectral information
        components = self._parse_components(pts_source_definition, "Point source")

        return point_source.PointSource(
            self._source_name,
            sky_position=this_sky_direction,
            components=components,
        )

    @staticmethod
    def _parse_coordinate(name, definition, default_bounds):
        """Instance one coordinate, using default_bounds if it has no bounds."""

        coordinate = ParameterParser(name, definition).get_variable()

        if coordinate.bounds == (None, None):

            coordinate.bounds = default_bounds

        return coordinate

    def _parse_sky_direction(self, sky_direction_definition):
        """
        Instance a SkyDirection from the coordinates in the definition.

        The coordinates are either the pair 'ra' and 'dec' or the pair 'l' and
        'b'. An optional 'equinox' is passed along as it is.

        :raises ModelSyntaxError: if there is no valid pair of coordinates
        """

        coordinates = {}

        if "ra" in sky_direction_definition and "dec" in sky_direction_definition:

            coordinates["ra"] = self._parse_coordinate(
                "ra", sky_direction_definition["ra"], (0, 360)
            )

            coordinates["dec"] = self._parse_coordinate(
                "dec", sky_direction_definition["dec"], (-90, 90)
            )

        elif "l" in sky_direction_definition and "b" in sky_direction_definition:

            coordinates["l"] = self._parse_coordinate(
                "l", sky_direction_definition["l"], (0, 360)
            )

            coordinates["b"] = self._parse_coordinate(
                "b", sky_direction_definition["b"], (-90, 90)
            )

        else:  # pragma: no cover

            message = (
                "Position specification for source %s has an invalid coordinate pair. "
                " You need to specify either 'ra' and 'dec', or 'l' and 'b'."
                % self._source_name
            )

            log.error(message)

            raise ModelSyntaxError(message)

        # Check if there is a equinox specification
        if "equinox" in sky_direction_definition:

            coordinates["equinox"] = sky_direction_definition["equinox"]

        try:

            this_sky_direction = sky_direction.SkyDirection(**coordinates)

        except sky_direction.WrongCoordinatePair as exc:  # pragma: no cover

            message = (
                "Position specification for source %s has an invalid coordinate pair"
                % self._source_name
            )

            log.error(message)

            raise ModelSyntaxError(message) from exc

        return this_sky_direction

    def _parse_polarization_shape(self, par, par_definition):
        """
        Instance the function which describes one polarization parameter.

        The name of the function is the first key of par_definition.

        :raises ModelSyntaxError: if a KeyError is raised while parsing
        """

        try:

            function_name = list(par_definition.keys())[0]
            parameters_definition = par_definition[function_name]

            # parse the function
            shape_parser = ShapeParser(self._source_name)

            return shape_parser.parse(
                par, function_name, parameters_definition, is_spatial=False
            )

        except KeyError as exc:  # pragma: no cover

            raise ModelSyntaxError(
                "The polarization_definititon of source %s is malformed"
                % (self._source_name)
            ) from exc

    def _parse_polarization(self, polarization_definititon):
        """
        Instance the polarization described by the definition.

        A definition with both 'degree' and 'angle' gives a LinearPolarization,
        one with any of 'I', 'Q', 'U', 'V' gives a StokesPolarization. Any other
        definition gives a default Polarization.
        """

        if "degree" in polarization_definititon and "angle" in polarization_definititon:

            par_dict = {"degree": None, "angle": None}

            par_bounds = {"degree": (0, 100), "angle": (0, 180)}

            for par in list(polarization_definititon.keys()):

                par_definition = polarization_definititon[par]

                if list(par_definition.keys())[0] == "value":

                    # Plain parameter, with fixed bounds
                    par_dict[par] = ParameterParser(par, par_definition).get_variable()

                    par_dict[par].bounds = par_bounds[par]

                else:

                    # The parameter is described by a function
                    par_dict[par] = self._parse_polarization_shape(par, par_definition)

            return polarization.LinearPolarization(**par_dict)

        if any(stokes in polarization_definititon for stokes in ("I", "U", "Q", "V")):

            par_dict = {"I": None, "Q": None, "U": None, "V": None}

            for par in list(polarization_definititon.keys()):

                par_dict[par] = self._parse_polarization_shape(
                    par, polarization_definititon[par]
                )

            return polarization.StokesPolarization(**par_dict)

        # just make a default polarization
        return polarization.Polarization()

    def _parse_spectral_component(self, component_name, component_definition):
        """
        Instance one SpectralComponent (shape plus polarization).

        The name of the shape is the first key of component_definition. The
        links, extra setups and external functions found in the shape are added
        to the ones of this parser.

        :raises ModelSyntaxError: if the component is malformed
        """

        # Parse the shape definition, which is the first to occur
        try:

            function_name = list(component_definition.keys())[0]
            parameters_definition = component_definition[function_name]

        except (KeyError, IndexError) as exc:  # pragma: no cover

            message = "The component %s of source %s is malformed" % (
                component_name,
                self._source_name,
            )

            log.error(message)

            raise ModelSyntaxError(message) from exc

        # parse the function
        shape_parser = ShapeParser(self._source_name)

        shape = shape_parser.parse(
            component_name,
            function_name,
            parameters_definition,
            is_spatial=False,
        )

        # Get the links and extra setups, if any
        self._links.extend(shape_parser.links)
        self._extra_setups.extend(shape_parser.extra_setups)
        self._external_functions.extend(shape_parser.external_functions)

        if "polarization" in component_definition:

            this_polarization = self._parse_polarization(
                component_definition["polarization"]
            )

        else:

            this_polarization = polarization.Polarization()

        return spectral_component.SpectralComponent(
            component_name, shape, this_polarization
        )

    def _parse_extended_source(self, ext_source_definition):
        """
        Instance an ExtendedSource from its definition.

        The first item of the definition is the spatial shape. The links, extra
        setups and external functions found in it are added to the ones of this
        parser.
        """

        # The first item in the dictionary is the definition of the extended shape
        name_of_spatial_shape = list(ext_source_definition.keys())[0]

        spatial_shape_parser = ShapeParser(self._source_name)

        spatial_shape = spatial_shape_parser.parse(
            "n.a.",
            name_of_spatial_shape,
            list(ext_source_definition.values())[0],
            is_spatial=True,
        )

        # Get the links and extra setups, if any
        self._links.extend(spatial_shape_parser.links)
        self._extra_setups.extend(spatial_shape_parser.extra_setups)
        self._external_functions.extend(spatial_shape_parser.external_functions)

        # Parse the spectral information
        components = self._parse_components(ext_source_definition, "Ext. source")

        return extended_source.ExtendedSource(
            self._source_name, spatial_shape, components=components
        )
class ShapeParser(object):
    """
    Instance functions (shapes, laws, priors) from their definitions.

    The links, extra setups and external functions found while parsing are
    accumulated in the properties with the same names.

    :param source_name: name of the source the shapes belong to, used to build
        the paths and the error messages
    """

    def __init__(self, source_name):

        self._source_name = source_name

        self._links = []
        self._extra_setups = []
        self._external_functions = []

    @property
    def links(self):
        """List of dictionaries with keys 'parameter_path', 'law' and 'variable'."""

        return self._links

    @property
    def extra_setups(self):
        """List of dictionaries with keys 'function_path' and 'extra_setup'."""

        return self._extra_setups

    @property
    def external_functions(self):
        """
        List of dictionaries with keys 'function_path', 'external_functions'
        and 'is_composite'.
        """

        return self._external_functions

    def parse(
        self,
        component_name,
        function_name,
        parameters_definition,
        is_spatial=False,
    ):
        """
        Instance the function with the given name and set it up.

        :param component_name: name of the component the function belongs to,
            used in the paths (for non-spatial shapes) and the error messages
        :param function_name: name of the function
        :param parameters_definition: dictionary with the definitions of the
            parameters and properties of the function, and optionally the keys
            'expression', 'extra_setup' and 'external_functions'
        :param is_spatial: if True the paths are built without the "spectrum"
            and component parts
        :return: the instance of the function
        :raises ModelSyntaxError: if the function is not known or if the
            definition is incomplete
        """

        return self._parse_shape_definition(
            component_name, function_name, parameters_definition, is_spatial
        )

    @staticmethod
    def _fix(value):
        """Replace the new lines with spaces.

        Sometimes YAML adds new lines in the middle of definitions, such as in
        units.
        """

        return value.replace("\n", " ")

    def _function_path(self, component_name, function_name, is_spatial):
        """Return the path of the function inside the model."""

        if is_spatial:

            return ".".join([self._source_name, function_name])

        return ".".join(
            [self._source_name, "spectrum", component_name, function_name]
        )

    def _raise_syntax_error(self, message, cause=None):
        """Log the message and raise a ModelSyntaxError which carries it."""

        log.error(message)

        raise ModelSyntaxError(message) from cause

    def _parse_shape_definition(
        self,
        component_name,
        function_name,
        parameters_definition,
        is_spatial=False,
    ):
        """Implementation of parse(), which is also used for laws and priors."""

        # Get the function
        if "expression" in parameters_definition:

            # This is a composite function
            function_instance = function.get_function(
                function_name, parameters_definition["expression"]
            )

            is_composite = True

        else:

            is_composite = False

            try:

                function_instance = function.get_function(function_name)

            except function.UnknownFunction as exc:  # pragma: no cover

                self._raise_syntax_error(
                    "Function %s, specified as shape for %s of source %s, is not a "
                    "known function"
                    % (function_name, component_name, self._source_name),
                    exc,
                )

        function_path = self._function_path(component_name, function_name, is_spatial)

        # Loop over the parameters of the function instance, instead of the
        # specification, so we can understand if there are parameters missing
        # from the specification
        for parameter_name, _ in function_instance.parameters.items():

            try:

                this_definition = parameters_definition[parameter_name]

            except KeyError as exc:  # pragma: no cover

                log.error(
                    "Function %s, specified as shape for %s of source %s, lacks "
                    "the definition for parameter %s"
                    % (
                        function_name,
                        component_name,
                        self._source_name,
                        parameter_name,
                    )
                )

                for k, v in parameters_definition.items():

                    log.error((k, v))

                raise ModelSyntaxError(
                    "Function %s, specified as shape for %s of source %s, lacks "
                    "the definition for parameter %s"
                    % (
                        function_name,
                        component_name,
                        self._source_name,
                        parameter_name,
                    )
                ) from exc

            this_parameter = function_instance.parameters[parameter_name]

            # Update the parameter. Note that the order is important, because
            # trying to set the value before the minimum and maximum could
            # result in a error.
            # All these specifications are optional. If they are not present,
            # then the default value already contained in the instance of the
            # function will be used

            # Ignore for a second the RuntimeWarning that is printed if the
            # default value in the function definition is outside the bounds
            # defined here
            with warnings.catch_warnings():

                warnings.simplefilter("ignore", RuntimeWarning)

                if "min_value" in this_definition:

                    this_parameter.min_value = this_definition["min_value"]

                if "max_value" in this_definition:

                    this_parameter.max_value = this_definition["max_value"]

            if "delta" in this_definition:

                this_parameter.delta = this_definition["delta"]

            if "free" in this_definition:

                this_parameter.free = this_definition["free"]

            if "unit" in this_definition:

                this_parameter.unit = self._fix(this_definition["unit"])

            # Now set the value, which must be present
            if "value" not in this_definition:  # pragma: no cover

                self._raise_syntax_error(
                    "The parameter %s in function %s, specified as shape for %s "
                    "of source %s, lacks a 'value' attribute"
                    % (
                        parameter_name,
                        function_name,
                        component_name,
                        self._source_name,
                    )
                )

            # Check if this is a linked parameter, i.e., if 'value' is
            # something like f(source.spectrum.powerlaw.index)
            matches = re.findall(r"f\((.+)\)", str(this_definition["value"]))

            if matches:

                # This is an expression which marks a parameter with a link to
                # another parameter (or an IndependentVariable such as time)
                linked_variable = matches[0]

                # Now get the law
                if "law" not in this_definition:  # pragma: no cover

                    self._raise_syntax_error(
                        "The parameter %s in function %s, specified as shape for %s "
                        "of source %s, is linked to %s but lacks a 'law' attribute"
                        % (
                            parameter_name,
                            function_name,
                            component_name,
                            self._source_name,
                            linked_variable,
                        )
                    )

                link_function_name = list(this_definition["law"].keys())[0]

                link_function_instance = self._parse_shape_definition(
                    component_name,
                    link_function_name,
                    this_definition["law"][link_function_name],
                )

                self._links.append(
                    {
                        "parameter_path": ".".join([function_path, parameter_name]),
                        "law": link_function_instance,
                        "variable": linked_variable,
                    }
                )

            else:

                # This is a normal (not linked) parameter
                this_parameter.value = this_definition["value"]

            # Setup the prior for this parameter, if it exists
            if "prior" in this_definition:

                # A name to display in case of errors
                name_for_errors = "prior for %s" % this_parameter.path

                prior_function_name = list(this_definition["prior"].keys())[0]

                prior_function = self._parse_shape_definition(
                    name_for_errors,
                    prior_function_name,
                    this_definition["prior"][prior_function_name],
                )

                # Set it as prior for current parameter
                this_parameter.prior = prior_function

        if function_instance.has_properties:

            # now collect the properties, which are stored in the parameters
            # definition as well
            for property_name, _ in function_instance.properties.items():

                try:

                    this_definition = parameters_definition[property_name]

                except KeyError as exc:  # pragma: no cover

                    log.error(
                        "Function %s, specified as shape for %s of source %s, lacks "
                        "the definition for property %s"
                        % (
                            function_name,
                            component_name,
                            self._source_name,
                            property_name,
                        )
                    )

                    for k, v in parameters_definition.items():

                        log.error((k, v))

                    raise ModelSyntaxError(
                        "Function %s, specified as shape for %s of source %s, lacks "
                        "the definition for property %s"
                        % (
                            function_name,
                            component_name,
                            self._source_name,
                            property_name,
                        )
                    ) from exc

                if "value" not in this_definition:

                    self._raise_syntax_error(
                        "The property %s in function %s, specified as shape for %s "
                        "of source %s, lacks a 'value' attribute"
                        % (
                            property_name,
                            function_name,
                            component_name,
                            self._source_name,
                        )
                    )

                function_instance.properties[property_name].value = this_definition[
                    "value"
                ]

        # Now handle extra_setup if any
        if "extra_setup" in parameters_definition:

            self._extra_setups.append(
                {
                    "function_path": function_path,
                    "extra_setup": parameters_definition["extra_setup"],
                }
            )

        if "external_functions" in parameters_definition:

            self._external_functions.append(
                {
                    "function_path": function_path,
                    "external_functions": parameters_definition["external_functions"],
                    "is_composite": is_composite,
                }
            )

        return function_instance