# Source code for dwave.system.warnings

# Copyright 2019 D-Wave Systems Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import enum
import logging

import dimod
import numpy as np
import collections.abc as abc

from dwave.embedding import broken_chains


class WarningAction(enum.Enum):
    """Actions that can be taken when a warning is issued.

    Members are looked up by upper-cased name when an action is given as a
    string, so member names are kept in upper case.
    """

    # the warning is dropped
    IGNORE = 'ignore'

    # the warning is recorded for later inspection
    SAVE = 'save'
# NOTE: logging and raising Python warnings may eventually be supported;
# that would add LOG = 'log' and RAISE = 'raise' members to WarningAction
# together with matching LOG / RAISE aliases here.

# module-level shortcuts for the enum members
IGNORE = WarningAction.IGNORE
SAVE = WarningAction.SAVE


def as_action(action):
    """Coerce ``action`` into a :class:`WarningAction` member.

    Args:
        action (:class:`WarningAction`/str):
            Either an enum member, which is returned unchanged, or the
            case-insensitive name of a member (for example ``'save'``).

    Returns:
        :class:`WarningAction`: The corresponding enum member.

    Raises:
        TypeError: If ``action`` is neither a :class:`WarningAction` nor a
            string.
        KeyError: If ``action`` is a string that does not name a member.
    """
    if isinstance(action, WarningAction):
        return action

    if not isinstance(action, str):
        raise TypeError('unknown warning action provided')

    # members are named in upper case, so the lookup is case-insensitive
    return WarningAction[action.upper()]
class ChainBreakWarning(UserWarning):
    """Category for warnings about broken chains in lowest-energy samples."""
class ChainLengthWarning(UserWarning):
    """Category for warnings about embedding chains that are too long."""
class TooFewSamplesWarning(UserWarning):
    """Category for warnings about too few ground states being sampled."""
class ChainStrengthWarning(UserWarning):
    """Base warning category for issues with the embedding chain strength."""
class EnergyScaleWarning(UserWarning):
    """Base warning category for issues with the relative bias strengths."""
class WarningHandler:
    """Collect warnings according to a configurable action.

    Warnings that are not ignored are appended to :attr:`saved` as
    dictionaries with the keys ``type``, ``message``, ``level`` and ``data``.

    Args:
        action (:class:`WarningAction`/str, optional):
            The action to take when a warning is issued. When omitted, the
            class-level default (ignore) applies.
    """

    # the default; __init__ shadows it with an instance attribute when an
    # explicit action is given
    action = WarningAction.IGNORE

    # todo: let user override __init__ parameters with kwargs
    def __init__(self, action=None):
        self.saved = []

        if action is not None:
            # promote from class attribute to object attribute
            self.action = as_action(action)

    def issue(self, msg, category=None, func=None, level=logging.WARNING,
              data=None):
        """Issue a warning.

        Args:
            msg (str):
                The warning message.

            category (Warning):
                The warning category class. Defaults to UserWarning.

            level (int):
                The level of warning severity. Uses the logging warning
                levels.

            func (function):
                A function that is executed in the case that the warning
                action is not IGNORE. The function should return a 2-tuple
                containing a bool specifying whether the warning should be
                saved/raised and any relevant data associated with the
                warning as a dictionary/None. This overrides anything
                provided in the `data` kwarg.

            data (dict):
                Any data relevant to the warning.

        Raises:
            TypeError: If the configured action is neither IGNORE nor SAVE.
        """
        action = as_action(self.action)  # user may have overwritten

        if action is IGNORE:
            return

        if func is not None:
            # func decides whether the warning applies and supplies its data
            valid, data = func()
            if not valid:
                return

        if category is None:
            category = UserWarning

        if data is None:
            data = {}

        if action is SAVE:
            self.saved.append(dict(type=category, message=msg, level=level,
                                   data=data))
        else:
            raise TypeError("unknown action")

    # some hard-coded warnings for convenience or for expensive operations

    def chain_length(self, embedding, length=7):
        """Issue a warning for every chain longer than ``length``.

        Args:
            embedding (dict):
                Mapping from source variable to its chain of target
                variables.

            length (int, optional, default=7):
                The longest chain length that does not trigger a warning.
        """
        if as_action(self.action) is IGNORE:
            return

        for v, chain in embedding.items():
            if len(chain) <= length:
                continue

            self.issue("Chain length greater than {}".format(length),
                       category=ChainLengthWarning,
                       data=dict(target_variables=chain,
                                 source_variables=[v]),
                       )

    def chain_break(self, sampleset, embedding):
        """Issue a warning for every broken chain in a lowest-energy sample.

        Args:
            sampleset:
                The sample set; only its lowest-energy samples (as returned
                by ``sampleset.lowest()``) are inspected.

            embedding (dict):
                Mapping from source variable to its chain of target
                variables.
        """
        if as_action(self.action) is IGNORE:
            return

        ground = sampleset.lowest()
        variables = list(embedding)
        chains = [embedding[v] for v in variables]
        # indexed as broken[sample row, chain index]
        broken = broken_chains(ground, chains)

        if not (len(sampleset) and broken.any()):
            return

        for nc, chain in enumerate(chains):
            for row in range(broken.shape[0]):
                if not broken[row, nc]:
                    continue

                self.issue("Lowest-energy samples contain a broken chain",
                           category=ChainBreakWarning,
                           level=logging.ERROR,
                           data=dict(target_variables=chain,
                                     source_variables=[variables[nc]],
                                     sample_index=row),
                           )

    def chain_strength(self, bqm, chain_strength, embedding=None):
        """Issues a warning when any quadratic biases are greater than the
        given chain strength.

        Args:
            bqm:
                A binary quadratic model; its ``quadratic`` mapping is read.

            chain_strength (number/Mapping):
                Either a single chain strength, or a mapping from variable
                to chain strength. With a mapping, each interaction is
                compared against the smaller strength of its two variables.

            embedding (dict, optional):
                Mapping from source variable to chain. When given and no
                chain is longer than one, no warning is issued.
        """
        if as_action(self.action) is IGNORE:
            return

        if embedding is not None:
            if not embedding or all(len(chain) <= 1
                                    for chain in embedding.values()):
                # the chains are all length 1 so don't have to worry about
                # strength
                return

        # biases equal in magnitude to the chain strength are reported too
        if isinstance(chain_strength, abc.Mapping):
            interactions = [(u, v) for (u, v), bias in bqm.quadratic.items()
                            if abs(bias) >= min(chain_strength[u],
                                                chain_strength[v])]
        else:
            interactions = [uv for uv, bias in bqm.quadratic.items()
                            if abs(bias) >= chain_strength]

        if interactions:
            self.issue("Some quadratic biases are stronger than the given "
                       "chain strength",
                       category=ChainStrengthWarning,
                       level=logging.WARNING,
                       data=dict(source_interactions=interactions))

    def energy_scale(self, bqm):
        """Issues a warning if some biases are 10^3 times stronger than
        others.

        A model without any biases issues no warning.

        Args:
            bqm (:class:`dimod.BinaryQuadraticModel`/tuple):
                A binary quadratic model, a tuple of the form `(Q)` where `Q`
                is a QUBO-dictionary, or a tuple of the form `(h, J)` where
                `h` and `J` are Ising problem dictionaries.

        Raises:
            TypeError: If ``bqm`` is a tuple whose length is not 1 or 2.
        """
        if as_action(self.action) is IGNORE:
            return

        if isinstance(bqm, tuple):
            if len(bqm) == 1:
                bqm = dimod.BinaryQuadraticModel.from_qubo(*bqm)
            elif len(bqm) == 2:
                bqm = dimod.BinaryQuadraticModel.from_ising(*bqm)
            else:
                raise TypeError("bqm should be a binary quadratic model, a "
                                "1-tuple or a 2-tuple")

        # default=0 keeps an empty model from raising ValueError in max()
        max_bias = max(map(abs, bqm.linear.values()), default=0)
        max_bias = max(max_bias,
                       max(map(abs, bqm.quadratic.values()), default=0))

        # anything weaker than this threshold is reported
        max_bias *= 10 ** -3

        variables = [v for v, bias in bqm.linear.items()
                     if abs(bias) < max_bias]
        interactions = [uv for uv, bias in bqm.quadratic.items()
                        if abs(bias) < max_bias]

        data = dict()
        if variables:
            data.update(source_variables=variables)
        if interactions:
            data.update(source_interactions=interactions)

        if data:
            self.issue("Some biases are 10^3 times stronger than others",
                       category=EnergyScaleWarning,
                       level=logging.WARNING,
                       data=data)

    def too_few_samples(self, sampleset):
        """Issues a warning when the number of ground states found is within
        the sampling error threshold.

        The threshold is the square root of the total number of occurrences
        in ``sampleset``.

        Args:
            sampleset:
                The sample set; its lowest-energy samples (as returned by
                ``sampleset.lowest()``) are counted as ground states.
        """
        # normalize like every other method so a string action such as
        # 'ignore' also skips the computation below
        if as_action(self.action) is IGNORE:
            return

        ground = sampleset.lowest()
        total_ground = np.sum(ground.record.num_occurrences)
        total_samples = np.sum(sampleset.record.num_occurrences)
        sampling_error = np.sqrt(total_samples)

        if total_ground <= sampling_error:
            self.issue("Number of ground states found is within sampling "
                       "error",
                       category=TooFewSamplesWarning,
                       level=logging.WARNING,
                       data=dict(number_of_ground_states=total_ground,
                                 num_reads=total_samples,
                                 sampling_error_rate=sampling_error),
                       )