import json import random import string import warnings import numpy as np from . import colors try: from IPython.display import HTML from IPython.display import display as ipython_display have_ipython = True except ImportError: have_ipython = False # TODO: we should support text output explanations (from models that output text not numbers), this would require the force # the force plot and the coloring to update based on mouseovers (or clicks to make it fixed) of the output text def text( shap_values, num_starting_labels=0, grouping_threshold=0.01, separator="", xmin=None, xmax=None, cmax=None, display=True, ): """Plots an explanation of a string of text using coloring and interactive labels. The output is interactive HTML and you can click on any token to toggle the display of the SHAP value assigned to that token. Parameters ---------- shap_values : [numpy.array] List of arrays of SHAP values. Each array has the shap values for a string (#input_tokens x output_tokens). num_starting_labels : int Number of tokens (sorted in descending order by corresponding SHAP values) that are uncovered in the initial view. When set to 0, all tokens are covered. grouping_threshold : float If the component substring effects are less than a ``grouping_threshold`` fraction of an unlowered interaction effect, then we visualize the entire group as a single chunk. This is primarily used for explanations that were computed with fixed_context set to 1 or 0 when using the :class:`.explainers.Partition` explainer, since this causes interaction effects to be left on internal nodes rather than lowered. separator : string The string separator that joins tokens grouped by interaction effects and unbroken string spans. Defaults to the empty string ``""``. xmin : float Minimum shap value bound. xmax : float Maximum shap value bound. cmax : float Maximum absolute shap value for sample. Used for scaling colors for input tokens. 
display: bool Whether to display or return html to further manipulate or embed. Default: ``True`` Examples -------- See `text plot examples `_. """ def values_min_max(values, base_values): """Used to pick our axis limits.""" fx = base_values + values.sum() xmin = fx - values[values > 0].sum() xmax = fx - values[values < 0].sum() cmax = max(abs(values.min()), abs(values.max())) d = xmax - xmin xmin -= 0.1 * d xmax += 0.1 * d return xmin, xmax, cmax uuid = "".join(random.choices(string.ascii_lowercase, k=20)) # loop when we get multi-row inputs if len(shap_values.shape) == 2 and (shap_values.output_names is None or isinstance(shap_values.output_names, str)): xmin = 0 xmax = 0 cmax = 0 for i, v in enumerate(shap_values): values, clustering = unpack_shap_explanation_contents(v) tokens, values, group_sizes = process_shap_values(v.data, values, grouping_threshold, separator, clustering) if i == 0: xmin, xmax, cmax = values_min_max(values, v.base_values) continue xmin_i, xmax_i, cmax_i = values_min_max(values, v.base_values) if xmin_i < xmin: xmin = xmin_i if xmax_i > xmax: xmax = xmax_i if cmax_i > cmax: cmax = cmax_i out = "" for i, v in enumerate(shap_values): out += f"""

[{i}]
""" out += text( v, num_starting_labels=num_starting_labels, grouping_threshold=grouping_threshold, separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False, ) if display: _ipython_display_html(out) return else: return out if len(shap_values.shape) == 2 and shap_values.output_names is not None: xmin_computed = None xmax_computed = None cmax_computed = None for i in range(shap_values.shape[-1]): values, clustering = unpack_shap_explanation_contents(shap_values[:, i]) tokens, values, group_sizes = process_shap_values( shap_values[:, i].data, values, grouping_threshold, separator, clustering ) # if i == 0: # xmin, xmax, cmax = values_min_max(values, shap_values[:,i].base_values) # continue xmin_i, xmax_i, cmax_i = values_min_max(values, shap_values[:, i].base_values) if xmin_computed is None or xmin_i < xmin_computed: xmin_computed = xmin_i if xmax_computed is None or xmax_i > xmax_computed: xmax_computed = xmax_i if cmax_computed is None or cmax_i > cmax_computed: cmax_computed = cmax_i if xmin is None: xmin = xmin_computed if xmax is None: xmax = xmax_computed if cmax is None: cmax = cmax_computed out = f"""
outputs
""" output_values = shap_values.values.sum(0) + shap_values.base_values output_max = np.max(np.abs(output_values)) for i, name in enumerate(shap_values.output_names): scaled_value = 0.5 + 0.5 * float(output_values[i]) / (float(output_max) + 1e-8) color = colors.red_transparent_blue(scaled_value) color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])) # '#dddddd' if i == 0 else '#ffffff' border-bottom: {'3px solid #000000' if i == 0 else 'none'}; out += f"""
{name}
""" out += "

" for i, name in enumerate(shap_values.output_names): out += f"
" out += text( shap_values[:, i], num_starting_labels=num_starting_labels, grouping_threshold=grouping_threshold, separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False, ) out += "
" out += f"" out += "
" if display: _ipython_display_html(out) return else: return out # text_to_text(shap_values) # return if len(shap_values.shape) == 3: xmin_computed = None xmax_computed = None cmax_computed = None for i in range(shap_values.shape[-1]): for j in range(shap_values.shape[0]): values, clustering = unpack_shap_explanation_contents(shap_values[j, :, i]) tokens, values, group_sizes = process_shap_values( shap_values[j, :, i].data, values, grouping_threshold, separator, clustering ) xmin_i, xmax_i, cmax_i = values_min_max(values, shap_values[j, :, i].base_values) if xmin_computed is None or xmin_i < xmin_computed: xmin_computed = xmin_i if xmax_computed is None or xmax_i > xmax_computed: xmax_computed = xmax_i if cmax_computed is None or cmax_i > cmax_computed: cmax_computed = cmax_i if xmin is None: xmin = xmin_computed if xmax is None: xmax = xmax_computed if cmax is None: cmax = cmax_computed out = "" for i, v in enumerate(shap_values): out += f"""

[{i}]
""" out += text( v, num_starting_labels=num_starting_labels, grouping_threshold=grouping_threshold, separator=separator, xmin=xmin, xmax=xmax, cmax=cmax, display=False, ) if display: _ipython_display_html(out) return else: return out # set any unset bounds xmin_new, xmax_new, cmax_new = values_min_max(shap_values.values, shap_values.base_values) if xmin is None: xmin = xmin_new if xmax is None: xmax = xmax_new if cmax is None: cmax = cmax_new values, clustering = unpack_shap_explanation_contents(shap_values) tokens, values, group_sizes = process_shap_values( shap_values.data, values, grouping_threshold, separator, clustering ) # build out HTML output one word one at a time top_inds = np.argsort(-np.abs(values))[:num_starting_labels] out = "" # ev_str = str(shap_values.base_values) # vsum_str = str(values.sum()) # fx_str = str(shap_values.base_values + values.sum()) # uuid = ''.join(random.choices(string.ascii_lowercase, k=20)) encoded_tokens = [t.replace("<", "<").replace(">", ">").replace(" ##", "") for t in tokens] output_name = shap_values.output_names if isinstance(shap_values.output_names, str) else "" out += svg_force_plot( values, shap_values.base_values, shap_values.base_values + values.sum(), encoded_tokens, uuid, xmin, xmax, output_name, ) out += ( "
inputs
" ) for i, token in enumerate(tokens): scaled_value = 0.5 + 0.5 * values[i] / (cmax + 1e-8) color = colors.red_transparent_blue(scaled_value) color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])) # display the labels for the most important words label_display = "none" wrapper_display = "inline" if i in top_inds: label_display = "block" wrapper_display = "inline-block" # create the value_label string value_label = "" if group_sizes[i] == 1: value_label = str(values[i].round(3)) else: value_label = str(values[i].round(3)) + " / " + str(group_sizes[i]) # the HTML for this token out += f"""
{value_label}
{token.replace("<", "<").replace(">", ">").replace(" ##", "")}
""" out += "
def process_shap_values(tokens, values, grouping_threshold, separator, clustering=None, return_meta_data=False):
    """Group tokens and attribution values for display.

    When ``values`` has one entry per token the inputs are passed through
    unchanged. When the lengths differ, ``values`` is treated as hierarchical:
    the first ``len(tokens)`` entries belong to the leaves and entry
    ``len(tokens) + i`` belongs to the internal node described by row ``i`` of
    ``clustering``. Interaction credit stored on internal nodes is pushed down
    to the children (split evenly), and any subtree whose interaction effect
    dominates its parts is collapsed into a single joined token.

    Parameters
    ----------
    tokens : sequence of str
        The input tokens.
    values : numpy.array
        Attribution values, either one per token or one per tree node.
    grouping_threshold : float
        A subtree rooted at node ``i`` is collapsed when the larger of its
        children's ``max_values`` is below
        ``grouping_threshold * abs(values[i]) / group_size``.
    separator : str
        String used to join the tokens of a collapsed group.
    clustering : numpy.array, optional
        Partition tree; columns 0 and 1 of each row hold the child node
        indices. Required when ``len(values) != len(tokens)``.
    return_meta_data : bool
        When true, also return the token-to-node mapping and the ids of the
        nodes that were emitted.

    Returns
    -------
    tuple
        ``(tokens, values, group_sizes)`` or, with ``return_meta_data``,
        ``(tokens, values, group_sizes, token_id_to_node_id_mapping,
        collapsed_node_ids)``. ``group_sizes`` and the mapping are integer
        arrays in both the flat and the hierarchical case.

    Raises
    ------
    ValueError
        If the lengths differ and no ``clustering`` was given.
    """
    M = len(tokens)

    if len(values) == M:
        # Flat attributions: every token is its own group. Integer dtype keeps
        # this branch consistent with the hierarchical one and lets callers use
        # the sizes as slice bounds.
        group_sizes = np.ones(M, dtype=int)
        token_id_to_node_id_mapping = np.arange(M)
        collapsed_node_ids = np.arange(M)
    else:
        # make sure we were given a partition tree
        if clustering is None:
            raise ValueError(
                "The length of the attribution values must match the number of "
                "tokens if shap_values.clustering is None! When passing hierarchical "
                "attributions the clustering is also required."
            )

        num_nodes = len(values)

        # compute the groups, lower_values, and max_values bottom-up
        groups = [[i] for i in range(M)]
        lower_values = np.zeros(num_nodes)
        lower_values[:M] = values[:M]
        max_values = np.zeros(num_nodes)
        max_values[:M] = np.abs(values[:M])
        for i in range(clustering.shape[0]):
            li = int(clustering[i, 0])
            ri = int(clustering[i, 1])
            node = M + i
            groups.append(groups[li] + groups[ri])
            lower_values[node] = lower_values[li] + lower_values[ri] + values[node]
            max_values[node] = max(abs(values[node]) / len(groups[node]), max_values[li], max_values[ri])

        # compute the upper_values: credit inherited from ancestors, split
        # evenly between the two children. An explicit stack is used instead of
        # recursion so deep partition trees cannot exhaust the call stack.
        upper_values = np.zeros(num_nodes)
        pending = [(num_nodes - 1, 0)]
        while pending:
            node, credit = pending.pop()
            upper_values[node] = credit
            if node < M:
                continue
            half = (credit + values[node]) * 0.5
            pending.append((int(clustering[node - M, 0]), half))
            pending.append((int(clustering[node - M, 1]), half))

        # the group_values comes from the dividends above them and below them
        group_values = lower_values + upper_values

        # merge all the tokens in groups dominated by interaction effects
        # (since we don't want to hide those)
        new_tokens = []
        new_values = []
        group_sizes = []

        # meta data
        token_id_to_node_id_mapping = np.zeros((M,), dtype=int)
        collapsed_node_ids = []

        # Pre-order walk, left child first, so emitted tokens keep their order.
        stack = [len(group_values) - 1]
        while stack:
            node = stack.pop()

            # leaves are emitted as-is
            if 0 <= node < M:
                new_tokens.append(tokens[node])
                new_values.append(group_values[node])
                group_sizes.append(1)
                collapsed_node_ids.append(node)
                token_id_to_node_id_mapping[node] = node
                continue

            # compute the dividend at internal nodes
            li = int(clustering[node - M, 0])
            ri = int(clustering[node - M, 1])
            dv = abs(values[node]) / len(groups[node])

            if max(max_values[li], max_values[ri]) < dv * grouping_threshold:
                # the interaction level is too high: treat the whole group as
                # one token (groups[node] is groups[li] followed by groups[ri])
                new_tokens.append(separator.join(tokens[g] for g in groups[node]))
                new_values.append(group_values[node])
                group_sizes.append(len(groups[node]))
                collapsed_node_ids.append(node)
                for g in groups[node]:
                    token_id_to_node_id_mapping[g] = node
            else:
                # push right first so the left subtree is processed first
                stack.append(ri)
                stack.append(li)

        # replace the incoming parameters with the grouped versions
        tokens = np.array(new_tokens)
        values = np.array(new_values)
        group_sizes = np.array(group_sizes, dtype=int)
        collapsed_node_ids = np.array(collapsed_node_ids)

    if return_meta_data:
        return tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids
    return tokens, values, group_sizes
output_name): def xpos(xval): return 100 * (xval - xmin) / (xmax - xmin + 1e-8) s = "" s += '' ### x-axis marks ### # draw x axis line s += '' # draw base value def draw_tick_mark(xval, label=None, bold=False, backing=False): s = "" s += f'' if not bold: if backing: s += f'{xval:g}' s += f'{xval:g}' else: if backing: s += f'{xval:g}' s += f'{xval:g}' if label is not None: s += f'{label}' return s xcenter = round((xmax + xmin) / 2, int(round(1 - np.log10(xmax - xmin + 1e-8)))) s += draw_tick_mark(xcenter) # np.log10(xmax - xmin) tick_interval = round((xmax - xmin) / 7, int(round(1 - np.log10(xmax - xmin + 1e-8)))) # tick_interval = (xmax - xmin) / 7 side_buffer = (xmax - xmin) / 14 for i in range(1, 10): pos = xcenter - i * tick_interval if pos < xmin + side_buffer: break s += draw_tick_mark(pos) for i in range(1, 10): pos = xcenter + i * tick_interval if pos > xmax - side_buffer: break s += draw_tick_mark(pos) s += draw_tick_mark(base_values, label="base value", backing=True) s += draw_tick_mark( fx, bold=True, label=f'f{output_name}(inputs)', backing=True ) ### Positive value marks ### red = (float(colors.red_rgb[0]) * 255, float(colors.red_rgb[1])* 255, float(colors.red_rgb[2])* 255) light_red = (255, 195, 213) # draw base red bar x = fx - values[values > 0].sum() w = 100 * values[values > 0].sum() / (xmax - xmin + 1e-8) s += f'' # draw underline marks and the text labels pos = fx last_pos = pos inds = [i for i in np.argsort(-np.abs(values)) if values[i] > 0] for i, ind in enumerate(inds): v = values[ind] pos -= v # a line under the bar to animate s += f'' # the text label cropped and centered s += f'{values[ind].round(3)}' # the text label cropped and centered s += f'' s += ' ' s += f' {tokens[ind].strip()}' s += " " s += "" last_pos = pos # draw the divider padding (which covers the text near the dividers) pos = fx for i, ind in enumerate(inds): v = values[ind] pos -= v if i != 0: for j in range(4): s += f'' s += f' ' s += f' ' s += " " s += "" if i + 1 != 
len(inds): for j in range(4): s += f'' s += f' ' s += f' ' s += " " s += "" last_pos = pos # center padding s += f'' # cover up a notch at the end of the red bar pos = fx - values[values > 0].sum() s += '' s += f' ' s += ' ' s += " " s += "" # draw the light red divider lines and a rect to handle mouseover events pos = fx last_pos = pos for i, ind in enumerate(inds): v = values[ind] pos -= v # divider line if i + 1 != len(inds): s += '' s += f' ' s += f' ' s += " " s += "" # mouse over rectangle s += f'' # draw underline marks and the text labels pos = fx last_pos = pos inds = [i for i in np.argsort(-np.abs(values)) if values[i] < 0] for i, ind in enumerate(inds): v = values[ind] pos -= v # a line under the bar to animate s += f'' # the value text s += f'{values[ind].round(3)}' # the text label cropped and centered s += f'' s += ' ' s += f' {tokens[ind].strip()}' s += " " s += "" last_pos = pos # draw the divider padding (which covers the text near the dividers) pos = fx for i, ind in enumerate(inds): v = values[ind] pos -= v if i != 0: for j in range(4): s += f'' s += f' ' s += f' ' s += " " s += "" if i + 1 != len(inds): for j in range(4): s += f'' s += f' ' s += f' ' s += " " s += "" last_pos = pos # center padding s += f'' # cover up a notch at the end of the blue bar pos = fx - values[values < 0].sum() s += '' s += f' ' s += ' ' s += " " s += "" # draw the light blue divider lines and a rect to handle mouseover events pos = fx last_pos = pos for i, ind in enumerate(inds): v = values[ind] pos -= v # divider line if i + 1 != len(inds): s += '' s += f' ' s += f' ' s += " " s += "" # mouse over rectangle s += f'= 0: new_tokens.append(tokens[i]) new_values.append(group_values[i]) group_sizes.append(1) else: # compute the dividend at internal nodes li = partition_tree[i - M, 0] ri = partition_tree[i - M, 1] dv = abs(shap_values[i]) / len(groups[i]) # if the interaction level is too high then just treat this whole group as one token if dv > grouping_threshold * 
max(max_values[li], max_values[ri]): new_tokens.append( separator.join([tokens[g] for g in groups[li]]) + separator + separator.join([tokens[g] for g in groups[ri]]) ) new_values.append(group_values[i] / len(groups[i])) group_sizes.append(len(groups[i])) # if interaction level is not too high we recurse else: merge_tokens(new_tokens, new_values, group_sizes, li) merge_tokens(new_tokens, new_values, group_sizes, ri) merge_tokens(new_tokens, new_shap_values, group_sizes, len(group_values) - 1) # replance the incoming parameters with the grouped versions tokens = np.array(new_tokens) shap_values = np.array(new_shap_values) group_sizes = np.array(group_sizes) M = len(tokens) else: group_sizes = np.ones(M) # build out HTML output one word one at a time top_inds = np.argsort(-np.abs(shap_values))[:num_starting_labels] maxv = shap_values.max() minv = shap_values.min() out = "" for i in range(M): scaled_value = 0.5 + 0.5 * shap_values[i] / max(abs(maxv), abs(minv)) color = colors.red_transparent_blue(scaled_value) color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])) # display the labels for the most important words label_display = "none" wrapper_display = "inline" if i in top_inds: label_display = "block" wrapper_display = "inline-block" # create the value_label string value_label = "" if group_sizes[i] == 1: value_label = str(shap_values[i].round(3)) else: value_label = str((shap_values[i] * group_sizes[i]).round(3)) + " / " + str(group_sizes[i]) # the HTML for this token out += ( "
" + "
" + value_label + "
" + "
" + tokens[i].replace("<", "<").replace(">", ">").replace(" ##", "") + "
" + "
" ) return _ipython_display_html(out) def text_to_text(shap_values): # unique ID added to HTML elements and function to avoid collision of different instances uuid = "".join(random.choices(string.ascii_lowercase, k=20)) saliency_plot_markup = saliency_plot(shap_values) heatmap_markup = heatmap(shap_values) html = f"""
Visualization Type:
{heatmap_markup}
""" javascript = f""" """ _ipython_display_html(javascript + html) def saliency_plot(shap_values): uuid = "".join(random.choices(string.ascii_lowercase, k=20)) unpacked_values, clustering = unpack_shap_explanation_contents(shap_values) tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids = process_shap_values( shap_values.data, unpacked_values[:, 0], 1, "", clustering, True ) def compress_shap_matrix(shap_matrix, group_sizes): compressed_matrix = np.zeros((group_sizes.shape[0], shap_matrix.shape[1])) counter = 0 for index in range(len(group_sizes)): compressed_matrix[index, :] = np.sum(shap_matrix[counter : counter + group_sizes[index], :], axis=0) counter += group_sizes[index] return compressed_matrix compressed_shap_matrix = compress_shap_matrix(shap_values.values, group_sizes) # generate background colors of saliency plot def get_colors(shap_values): input_colors = [] cmax = max(abs(compressed_shap_matrix.min()), abs(compressed_shap_matrix.max())) for row_index in range(compressed_shap_matrix.shape[0]): input_colors_row = [] for col_index in range(compressed_shap_matrix.shape[1]): scaled_value = 0.5 + 0.5 * compressed_shap_matrix[row_index, col_index] / cmax color = colors.red_transparent_blue(scaled_value) color = "rgba" + str((float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))) input_colors_row.append(color) input_colors.append(input_colors_row) return input_colors model_output = shap_values.output_names input_colors = get_colors(shap_values) out = '' # add top row containing input tokens out += "" out += "" for j in range(compressed_shap_matrix.shape[0]): out += ( "" ) out += "" for row_index in range(compressed_shap_matrix.shape[1]): out += "" out += ( "" ) for col_index in range(compressed_shap_matrix.shape[0]): out += ( '" ) out += "" out += "
" + tokens[j].replace("<", "<").replace(">", ">").replace(" ##", "").replace("▁", "").replace("Ġ", "") + "
" + model_output[row_index] .replace("<", "<") .replace(">", ">") .replace(" ##", "") .replace("▁", "") .replace("Ġ", "") + "' + str(round(compressed_shap_matrix[col_index][row_index], 3)) + "
" saliency_plot_html = f"""
Saliency Plot
x-axis: Output Text
y-axis: Input Text
{out}
""" return saliency_plot_html def heatmap(shap_values): # constants TREE_NODE_KEY_TOKENS = "tokens" TREE_NODE_KEY_CHILDREN = "children" uuid = "".join(random.choices(string.ascii_lowercase, k=20)) def get_color(shap_value, cmax): scaled_value = 0.5 + 0.5 * shap_value / cmax color = colors.red_transparent_blue(scaled_value) color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])) return color def process_text_to_text_shap_values(shap_values): processed_values = [] unpacked_values, clustering = unpack_shap_explanation_contents(shap_values) max_val = 0 for index, output_token in enumerate(shap_values.output_names): tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids = process_shap_values( shap_values.data, unpacked_values[:, index], 1, "", clustering, True ) processed_value = { "tokens": tokens, "values": values, "group_sizes": group_sizes, "token_id_to_node_id_mapping": token_id_to_node_id_mapping, "collapsed_node_ids": collapsed_node_ids, } processed_values.append(processed_value) max_val = max(max_val, np.max(values)) return processed_values, max_val # unpack input tokens and output tokens model_input = shap_values.data model_output = shap_values.output_names processed_values, max_val = process_text_to_text_shap_values(shap_values) # generate dictionary containing precomputed background colors and shap values which are addressable by html token ids colors_dict = {} shap_values_dict = {} token_id_to_node_id_mapping = {} cmax = max(abs(shap_values.values.min()), abs(shap_values.values.max()), max_val) # input token -> output token color and label value mapping for row_index in range(len(model_input)): color_values = {} shap_values_list = {} for col_index in range(len(model_output)): color_values[uuid + "_output_flat_token_" + str(col_index)] = "rgba" + str( get_color(shap_values.values[row_index][col_index], cmax) ) shap_values_list[uuid + "_output_flat_value_label_" + str(col_index)] = round( 
shap_values.values[row_index][col_index], 3 ) colors_dict[f"{uuid}_input_node_{row_index}_content"] = color_values shap_values_dict[f"{uuid}_input_node_{row_index}_content"] = shap_values_list # output token -> input token color and label value mapping for col_index in range(len(model_output)): color_values = {} shap_values_list = {} for row_index in range(processed_values[col_index]["collapsed_node_ids"].shape[0]): color_values[ uuid + "_input_node_" + str(processed_values[col_index]["collapsed_node_ids"][row_index]) + "_content" ] = "rgba" + str(get_color(processed_values[col_index]["values"][row_index], cmax)) shap_label_value_str = str(round(processed_values[col_index]["values"][row_index], 3)) if processed_values[col_index]["group_sizes"][row_index] > 1: shap_label_value_str += "/" + str(processed_values[col_index]["group_sizes"][row_index]) shap_values_list[ uuid + "_input_node_" + str(processed_values[col_index]["collapsed_node_ids"][row_index]) + "_label" ] = shap_label_value_str colors_dict[uuid + "_output_flat_token_" + str(col_index)] = color_values shap_values_dict[uuid + "_output_flat_token_" + str(col_index)] = shap_values_list token_id_to_node_id_mapping_dict = {} for index, node_id in enumerate(processed_values[col_index]["token_id_to_node_id_mapping"].tolist()): token_id_to_node_id_mapping_dict[f"{uuid}_input_node_{index}_content"] = ( f"{uuid}_input_node_{int(node_id)}_content" ) token_id_to_node_id_mapping[uuid + "_output_flat_token_" + str(col_index)] = token_id_to_node_id_mapping_dict # convert python dictionary into json to be inserted into the runtime javascript environment colors_json = json.dumps(colors_dict) shap_values_json = json.dumps(shap_values_dict) token_id_to_node_id_mapping_json = json.dumps(token_id_to_node_id_mapping) javascript_values = ( " \n " ) def generate_tree(shap_values): num_tokens = len(shap_values.data) token_list = {} for index in range(num_tokens): node_content = {} node_content[TREE_NODE_KEY_TOKENS] = 
shap_values.data[index] node_content[TREE_NODE_KEY_CHILDREN] = {} token_list[str(index)] = node_content counter = num_tokens for pair in shap_values.clustering: first_node = str(int(pair[0])) second_node = str(int(pair[1])) new_node_content = {} new_node_content[TREE_NODE_KEY_CHILDREN] = { first_node: token_list[first_node], second_node: token_list[second_node], } token_list[str(counter)] = new_node_content counter += 1 del token_list[first_node] del token_list[second_node] return token_list tree = generate_tree(shap_values) # generates the input token html elements # each element contains the label value (initially hidden) and the token text input_text_html = "" def populate_input_tree(input_index, token_list_subtree, input_text_html): content = token_list_subtree[input_index] input_text_html += ( f'
' ) input_text_html += ( f'" if token_list_subtree[input_index][TREE_NODE_KEY_CHILDREN]: input_text_html += f'
' for child_index, child_content in token_list_subtree[input_index][TREE_NODE_KEY_CHILDREN].items(): input_text_html = populate_input_tree( child_index, token_list_subtree[input_index][TREE_NODE_KEY_CHILDREN], input_text_html ) input_text_html += "
" else: input_text_html += ( f'
" ) input_text_html += ( content[TREE_NODE_KEY_TOKENS] .replace("<", "<") .replace(">", ">") .replace(" ##", "") .replace("▁", "") .replace("Ġ", "") ) input_text_html += "
" input_text_html += "
" return input_text_html input_text_html = populate_input_tree(list(tree.keys())[0], tree, input_text_html) # generates the output token html elements output_text_html = "" for i in range(len(model_output)): output_text_html += ( "
" f"
" "
" f"
" + model_output[i] .replace("<", "<") .replace(">", ">") .replace(" ##", "") .replace("▁", "") .replace("Ġ", "") + "
" + "
" ) heatmap_html = f"""
Input/Output - Heatmap
Layout :
Input Text
{input_text_html}
Output Text
{output_text_html}
def unpack_shap_explanation_contents(shap_values):
    """Return ``(values, clustering)`` pulled from an explanation object.

    ``hierarchical_values`` is used when the attribute exists and is not
    ``None``; otherwise ``values`` is used. The chosen values are converted
    with ``np.array``. ``clustering`` is ``None`` when the attribute is absent.
    """
    hierarchical = getattr(shap_values, "hierarchical_values", None)
    chosen = shap_values.values if hierarchical is None else hierarchical
    return np.array(chosen), getattr(shap_values, "clustering", None)


def _ipython_display_html(data):
    """Display ``data`` as HTML via IPython, raising ImportError if IPython is missing."""
    if have_ipython:
        return ipython_display(HTML(data))
    raise ImportError(
        "IPython is required for this function but is not installed. Fix this with `pip install ipython`."
    )