marcossuzuki commited on
Commit
e11ec65
·
verified ·
1 Parent(s): 760fbb2

Fix colors from plots.text

Browse files
Files changed (1) hide show
  1. _text.py +1465 -0
_text.py ADDED
@@ -0,0 +1,1465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import string
4
+ import warnings
5
+
6
+ import numpy as np
7
+
8
+ from . import colors
9
+
10
+ try:
11
+ from IPython.display import HTML
12
+ from IPython.display import display as ipython_display
13
+
14
+ have_ipython = True
15
+ except ImportError:
16
+ have_ipython = False
17
+
18
+
19
+ # TODO: we should support text output explanations (from models that output text not numbers), this would require
20
+ # the force plot and the coloring to update based on mouseovers (or clicks to make it fixed) of the output text
21
def text(
    shap_values,
    num_starting_labels=0,
    grouping_threshold=0.01,
    separator="",
    xmin=None,
    xmax=None,
    cmax=None,
    display=True,
):
    """Plots an explanation of a string of text using coloring and interactive labels.

    The output is interactive HTML and you can click on any token to toggle the display of the
    SHAP value assigned to that token.

    Parameters
    ----------
    shap_values : [numpy.array]
        List of arrays of SHAP values. Each array has the shap values for a string (#input_tokens x output_tokens).

    num_starting_labels : int
        Number of tokens (sorted in descending order by corresponding SHAP values)
        that are uncovered in the initial view.
        When set to 0, all tokens are covered.

    grouping_threshold : float
        If the component substring effects are less than a ``grouping_threshold``
        fraction of an unlowered interaction effect, then we visualize the entire group
        as a single chunk. This is primarily used for explanations that were computed
        with fixed_context set to 1 or 0 when using the :class:`.explainers.Partition`
        explainer, since this causes interaction effects to be left on internal nodes
        rather than lowered.

    separator : string
        The string separator that joins tokens grouped by interaction effects and
        unbroken string spans. Defaults to the empty string ``""``.

    xmin : float
        Minimum shap value bound. When ``None`` it is computed from the data.

    xmax : float
        Maximum shap value bound. When ``None`` it is computed from the data.

    cmax : float
        Maximum absolute shap value for sample. Used for scaling colors for input tokens.
        When ``None`` it is computed from the data.

    display: bool
        Whether to display or return html to further manipulate or embed. Default: ``True``

    Returns
    -------
    str or None
        The generated HTML when ``display`` is ``False``. When ``display`` is ``True`` the
        HTML is passed to ``_ipython_display_html`` and ``None`` is returned.

    Examples
    --------
    See `text plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/text.html>`_.

    """

    def values_min_max(values, base_values):
        """Used to pick our axis limits."""
        fx = base_values + values.sum()
        xmin = fx - values[values > 0].sum()
        xmax = fx - values[values < 0].sum()
        cmax = max(abs(values.min()), abs(values.max()))
        d = xmax - xmin
        xmin -= 0.1 * d
        xmax += 0.1 * d

        return xmin, xmax, cmax

    def resolve_bounds(pieces):
        """Return (xmin, xmax, cmax), filling any bound the caller left as None.

        The fill-in values are the widest limits over ``pieces`` (an iterable of
        single-row explanations), so that every sub-plot shares one scale.
        """
        lo = hi = color_max = None
        for piece in pieces:
            piece_values, clustering = unpack_shap_explanation_contents(piece)
            _, piece_values, _ = process_shap_values(
                piece.data, piece_values, grouping_threshold, separator, clustering
            )
            lo_i, hi_i, cmax_i = values_min_max(piece_values, piece.base_values)
            if lo is None or lo_i < lo:
                lo = lo_i
            if hi is None or hi_i > hi:
                hi = hi_i
            if color_max is None or cmax_i > color_max:
                color_max = cmax_i

        # explicit caller-supplied bounds always win over the computed ones
        return (
            lo if xmin is None else xmin,
            hi if xmax is None else xmax,
            color_max if cmax is None else cmax,
        )

    def render_rows(lo, hi, color_max):
        """Render each row of ``shap_values`` below a dashed ``[i]`` divider, all on one shared scale."""
        html = ""
        for i, v in enumerate(shap_values):
            html += f"""
    <br>
    <hr style="height: 1px; background-color: #fff; border: none; margin-top: 18px; margin-bottom: 18px; border-top: 1px dashed #ccc;">
    <div align="center" style="margin-top: -35px;"><div style="display: inline-block; background: #fff; padding: 5px; color: #999; font-family: monospace">[{i}]</div>
    </div>
                """
            html += text(
                v,
                num_starting_labels=num_starting_labels,
                grouping_threshold=grouping_threshold,
                separator=separator,
                xmin=lo,
                xmax=hi,
                cmax=color_max,
                display=False,
            )
        return html

    def emit(html):
        """Either show the HTML (returning None) or hand it back to the caller."""
        if display:
            _ipython_display_html(html)
            return None
        return html

    uuid = "".join(random.choices(string.ascii_lowercase, k=20))

    # loop when we get multi-row inputs
    if len(shap_values.shape) == 2 and (shap_values.output_names is None or isinstance(shap_values.output_names, str)):
        # Previously this branch overwrote xmin/xmax/cmax unconditionally, silently
        # discarding bounds passed by the caller; now they are only filled when unset,
        # matching the other batch branches below.
        xmin, xmax, cmax = resolve_bounds(shap_values)
        return emit(render_rows(xmin, xmax, cmax))

    # a single input row with several named outputs: one plot per output, switchable in the browser
    if len(shap_values.shape) == 2 and shap_values.output_names is not None:
        xmin, xmax, cmax = resolve_bounds(shap_values[:, i] for i in range(shap_values.shape[-1]))

        out = f"""<div align='center'>
<script>
    document._hover_{uuid} = '_tp_{uuid}_output_0';
    document._zoom_{uuid} = undefined;
    function _output_onclick_{uuid}(i) {{
        var next_id = undefined;

        if (document._zoom_{uuid} !== undefined) {{
            document.getElementById(document._zoom_{uuid}+ '_zoom').style.display = 'none';

            if (document._zoom_{uuid} === '_tp_{uuid}_output_' + i) {{
                document.getElementById(document._zoom_{uuid}).style.display = 'block';
                document.getElementById(document._zoom_{uuid}+'_name').style.borderBottom = '3px solid #000000';
            }} else {{
                document.getElementById(document._zoom_{uuid}).style.display = 'none';
                document.getElementById(document._zoom_{uuid}+'_name').style.borderBottom = 'none';
            }}
        }}
        if (document._zoom_{uuid} !== '_tp_{uuid}_output_' + i) {{
            next_id = '_tp_{uuid}_output_' + i;
            document.getElementById(next_id).style.display = 'none';
            document.getElementById(next_id + '_zoom').style.display = 'block';
            document.getElementById(next_id+'_name').style.borderBottom = '3px solid #000000';
        }}
        document._zoom_{uuid} = next_id;
    }}
    function _output_onmouseover_{uuid}(i, el) {{
        if (document._zoom_{uuid} !== undefined) {{ return; }}
        if (document._hover_{uuid} !== undefined) {{
            document.getElementById(document._hover_{uuid} + '_name').style.borderBottom = 'none';
            document.getElementById(document._hover_{uuid}).style.display = 'none';
        }}
        document.getElementById('_tp_{uuid}_output_' + i).style.display = 'block';
        el.style.borderBottom = '3px solid #000000';
        document._hover_{uuid} = '_tp_{uuid}_output_' + i;
    }}
</script>
<div style=\"color: rgb(120,120,120); font-size: 12px;\">outputs</div>"""
        # color each output name by its model output (base value + summed attributions)
        output_values = shap_values.values.sum(0) + shap_values.base_values
        output_max = np.max(np.abs(output_values))
        for i, name in enumerate(shap_values.output_names):
            scaled_value = 0.5 + 0.5 * float(output_values[i]) / (float(output_max) + 1e-8)
            color = colors.red_transparent_blue(scaled_value)
            # plain Python floats so the tuple formats as "(r, g, b, a)" inside rgba(...)
            color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))
            out += f"""
<div style="display: inline; border-bottom: {'3px solid #000000' if i == 0 else 'none'}; background: rgba{color}; border-radius: 3px; padding: 0px" id="_tp_{uuid}_output_{i}_name"
    onclick="_output_onclick_{uuid}({i})"
    onmouseover="_output_onmouseover_{uuid}({i}, this);">{name}</div>"""
        out += "<br><br>"
        for i, _name in enumerate(shap_values.output_names):
            # shared-scale view, shown on hover (only the first output is visible initially)
            out += f"<div id='_tp_{uuid}_output_{i}' style='display: {'block' if i == 0 else 'none'};'>"
            out += text(
                shap_values[:, i],
                num_starting_labels=num_starting_labels,
                grouping_threshold=grouping_threshold,
                separator=separator,
                xmin=xmin,
                xmax=xmax,
                cmax=cmax,
                display=False,
            )
            out += "</div>"
            # "zoom" view, shown on click: bounds are left unset so this output gets its own scale
            out += f"<div id='_tp_{uuid}_output_{i}_zoom' style='display: none;'>"
            out += text(
                shap_values[:, i],
                num_starting_labels=num_starting_labels,
                grouping_threshold=grouping_threshold,
                separator=separator,
                display=False,
            )
            out += "</div>"
        out += "</div>"
        return emit(out)

    # several rows, each with several outputs: share one scale across every (row, output) pair
    if len(shap_values.shape) == 3:
        xmin, xmax, cmax = resolve_bounds(
            shap_values[j, :, i] for i in range(shap_values.shape[-1]) for j in range(shap_values.shape[0])
        )
        return emit(render_rows(xmin, xmax, cmax))

    # ---- single row, single output from here on ----

    # set any unset bounds
    xmin_new, xmax_new, cmax_new = values_min_max(shap_values.values, shap_values.base_values)
    if xmin is None:
        xmin = xmin_new
    if xmax is None:
        xmax = xmax_new
    if cmax is None:
        cmax = cmax_new

    values, clustering = unpack_shap_explanation_contents(shap_values)
    tokens, values, group_sizes = process_shap_values(
        shap_values.data, values, grouping_threshold, separator, clustering
    )

    # build out HTML output one word one at a time
    top_inds = np.argsort(-np.abs(values))[:num_starting_labels]
    out = ""

    # escape angle brackets once; the same strings are used in the force plot and the token list
    encoded_tokens = [t.replace("<", "&lt;").replace(">", "&gt;").replace(" ##", "") for t in tokens]
    output_name = shap_values.output_names if isinstance(shap_values.output_names, str) else ""
    out += svg_force_plot(
        values,
        shap_values.base_values,
        shap_values.base_values + values.sum(),
        encoded_tokens,
        uuid,
        xmin,
        xmax,
        output_name,
    )
    out += (
        "<div align='center'><div style=\"color: rgb(120,120,120); font-size: 12px; margin-top: -15px;\">inputs</div>"
    )
    for i, encoded_token in enumerate(encoded_tokens):
        scaled_value = 0.5 + 0.5 * values[i] / (cmax + 1e-8)
        color = colors.red_transparent_blue(scaled_value)
        color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))

        # display the labels for the most important words
        label_display = "none"
        wrapper_display = "inline"
        if i in top_inds:
            label_display = "block"
            wrapper_display = "inline-block"

        # create the value_label string
        if group_sizes[i] == 1:
            value_label = str(values[i].round(3))
        else:
            value_label = str(values[i].round(3)) + " / " + str(group_sizes[i])

        # the HTML for this token; the tags are closed across line breaks ("</div\n>") so that
        # no whitespace text node sits between the label and the token, which the onclick
        # handler relies on when it uses previousSibling
        out += f"""<div style='display: {wrapper_display}; text-align: center;'
    ><div style='display: {label_display}; color: #999; padding-top: 0px; font-size: 12px;'>{value_label}</div
        ><div id='_tp_{uuid}_ind_{i}'
            style='display: inline; background: rgba{color}; border-radius: 3px; padding: 0px'
            onclick="
            if (this.previousSibling.style.display == 'none') {{
                this.previousSibling.style.display = 'block';
                this.parentNode.style.display = 'inline-block';
            }} else {{
                this.previousSibling.style.display = 'none';
                this.parentNode.style.display = 'inline';
            }}"
            onmouseover="document.getElementById('_fb_{uuid}_ind_{i}').style.opacity = 1; document.getElementById('_fs_{uuid}_ind_{i}').style.opacity = 1;"
            onmouseout="document.getElementById('_fb_{uuid}_ind_{i}').style.opacity = 0; document.getElementById('_fs_{uuid}_ind_{i}').style.opacity = 0;"
        >{encoded_token}</div></div>"""
    out += "</div>"

    return emit(out)
376
+
377
+
378
def process_shap_values(tokens, values, grouping_threshold, separator, clustering=None, return_meta_data=False):
    """Turn (possibly hierarchical) attributions into one value per displayed token group.

    When ``values`` has exactly one entry per token the inputs are returned unchanged,
    with every group size equal to 1. Otherwise ``values`` is read as holding one entry
    per node of the binary tree described by ``clustering`` (leaves first, then one
    internal node per clustering row), and the tree is walked top-down: a subtree is
    collapsed into a single joined token when its children's largest per-token effect is
    below ``grouping_threshold`` times the node's own per-token interaction effect.

    Parameters
    ----------
    tokens : sequence of str
        The input tokens (the tree leaves).
    values : array-like
        Attribution values, either per token or per tree node.
    grouping_threshold : float
        Threshold controlling when a subtree is shown as a single chunk.
    separator : str
        String used to join the tokens of a collapsed subtree.
    clustering : array or None
        Rows whose first two columns are the child node ids of each internal node.
        Required when ``values`` is hierarchical.
    return_meta_data : bool
        When True also return the token-to-node mapping and the ids of the emitted nodes.

    Returns
    -------
    tuple
        ``(tokens, values, group_sizes)``, followed by
        ``(token_id_to_node_id_mapping, collapsed_node_ids)`` if ``return_meta_data``.

    Raises
    ------
    ValueError
        If the lengths differ and no ``clustering`` was given.
    """
    num_tokens = len(tokens)

    # Flat case: nothing to regroup, every token stands for itself.
    if len(values) == num_tokens:
        group_sizes = np.ones(num_tokens)
        token_id_to_node_id_mapping = np.arange(num_tokens)
        collapsed_node_ids = np.arange(num_tokens)
        if return_meta_data:
            return tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids
        return tokens, values, group_sizes

    # Hierarchical case: a partition tree is mandatory.
    if clustering is None:
        raise ValueError(
            "The length of the attribution values must match the number of "
            "tokens if shap_values.clustering is None! When passing hierarchical "
            "attributions the clustering is also required."
        )

    num_nodes = len(values)

    # Bottom-up pass: leaf membership, subtree totals and the largest per-token effect
    # found anywhere inside each subtree.
    members = [[leaf] for leaf in range(num_tokens)]
    subtree_total = np.zeros(num_nodes)
    subtree_total[:num_tokens] = values[:num_tokens]
    subtree_peak = np.zeros(num_nodes)
    subtree_peak[:num_tokens] = np.abs(values[:num_tokens])
    for row in range(clustering.shape[0]):
        node = num_tokens + row
        left = int(clustering[row, 0])
        right = int(clustering[row, 1])
        members.append(members[left] + members[right])
        subtree_total[node] = subtree_total[left] + subtree_total[right] + values[node]
        subtree_peak[node] = max(abs(values[node]) / len(members[node]), subtree_peak[left], subtree_peak[right])

    # Top-down pass: every internal node hands half of (its inherited credit + its own
    # value) to each child. An explicit stack visits nodes in the same pre-order
    # (left before right) as a recursive walk would.
    inherited = np.zeros(num_nodes)
    pending = [(num_nodes - 1, 0)]
    while pending:
        node, credit = pending.pop()
        if node < num_tokens:
            inherited[node] = credit
            continue
        left = int(clustering[node - num_tokens, 0])
        right = int(clustering[node - num_tokens, 1])
        inherited[node] = credit
        share = (credit + values[node]) * 0.5
        pending.append((right, share))
        pending.append((left, share))

    # the value shown for a node combines what lies below it and what was passed down to it
    node_values = subtree_total + inherited

    # Second top-down pass: emit either single tokens or whole collapsed subtrees.
    merged_tokens = []
    merged_values = []
    merged_sizes = []
    token_id_to_node_id_mapping = np.zeros((num_tokens,))
    collapsed_node_ids = []

    to_visit = [len(node_values) - 1]
    while to_visit:
        node = to_visit.pop()

        # leaves are always emitted as they are
        if 0 <= node < num_tokens:
            merged_tokens.append(tokens[node])
            merged_values.append(node_values[node])
            merged_sizes.append(1)
            collapsed_node_ids.append(node)
            token_id_to_node_id_mapping[node] = node
            continue

        left = int(clustering[node - num_tokens, 0])
        right = int(clustering[node - num_tokens, 1])
        # per-token share of this node's own interaction effect
        dividend = abs(values[node]) / len(members[node])

        if max(subtree_peak[left], subtree_peak[right]) < dividend * grouping_threshold:
            # the interaction dominates everything underneath: show the subtree as one chunk
            merged_tokens.append(separator.join(tokens[g] for g in members[left] + members[right]))
            merged_values.append(node_values[node])
            merged_sizes.append(len(members[node]))
            collapsed_node_ids.append(node)
            for g in members[left]:
                token_id_to_node_id_mapping[g] = node
            for g in members[right]:
                token_id_to_node_id_mapping[g] = node
        else:
            # otherwise descend, left child first
            to_visit.append(right)
            to_visit.append(left)

    # replace the incoming parameters with the grouped versions
    tokens = np.array(merged_tokens)
    values = np.array(merged_values)
    group_sizes = np.array(merged_sizes)
    token_id_to_node_id_mapping = np.array(token_id_to_node_id_mapping)
    collapsed_node_ids = np.array(collapsed_node_ids)

    if return_meta_data:
        return tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids
    return tokens, values, group_sizes
496
+
497
+
498
def svg_force_plot(values, base_values, fx, tokens, uuid, xmin, xmax, output_name):
    """Build the SVG markup of the force-plot bar: an axis plus stacked red and blue segments.

    Positive values are drawn as red segments to the left of ``fx`` and negative values
    as blue segments to the right of ``fx``, each ordered by decreasing magnitude
    starting from ``fx``. Every segment gets a hidden underline (``_fb_{uuid}_ind_{k}``)
    and a hidden value label (``_fs_{uuid}_ind_{k}``) that are revealed on mouseover; the
    mouseover handlers also underline the element with id ``_tp_{uuid}_ind_{k}``, which is
    expected to exist elsewhere in the page.

    Parameters
    ----------
    values : array
        One attribution per token. It is indexed with boolean masks and ``.round`` is
        called on its elements, so it is presumably a numpy array.
    base_values : number
        Position of the "base value" tick.
    fx : number
        Position of the bold model-output tick; the red and blue bars meet here.
    tokens : sequence of str
        Labels drawn inside the segments. They are inserted into the markup as-is.
    uuid : str
        Unique suffix used in the generated element ids.
    xmin, xmax : number
        Data values mapped to the left (0%) and right (100%) edges of the plot.
    output_name : str
        Rendered as a subscript in the ``f(inputs)`` label.

    Returns
    -------
    str
        The complete ``<svg>...</svg>`` markup.
    """

    def xpos(xval):
        # map a data value to a horizontal percentage; the epsilon guards against xmax == xmin
        return 100 * (xval - xmin) / (xmax - xmin + 1e-8)

    s = ""
    s += '<svg width="100%" height="80px">'

    ### x-axis marks ###

    # draw x axis line
    s += '<line x1="0" y1="33" x2="100%" y2="33" style="stroke:rgb(150,150,150);stroke-width:1" />'

    # helper that returns the markup for one tick mark with its numeric (and optional text) label
    def draw_tick_mark(xval, label=None, bold=False, backing=False):
        s = ""
        s += f'<line x1="{xpos(xval)}%" y1="33" x2="{xpos(xval)}%" y2="37" style="stroke:rgb(150,150,150);stroke-width:1" />'
        if not bold:
            if backing:
                # a thick white copy of the text drawn first, so the label stays readable over other ticks
                s += f'<text x="{xpos(xval)}%" y="27" font-size="13px" style="stroke:#ffffff;stroke-width:8px;" fill="rgb(255,255,255)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
            s += f'<text x="{xpos(xval)}%" y="27" font-size="12px" fill="rgb(120,120,120)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
        else:
            if backing:
                s += f'<text x="{xpos(xval)}%" y="27" font-size="13px" style="stroke:#ffffff;stroke-width:8px;" font-weight="bold" fill="rgb(255,255,255)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
            s += f'<text x="{xpos(xval)}%" y="27" font-size="13px" font-weight="bold" fill="rgb(0,0,0)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
        if label is not None:
            s += f'<text x="{xpos(xval)}%" y="10" font-size="12px" fill="rgb(120,120,120)" dominant-baseline="bottom" text-anchor="middle">{label}</text>'
        return s

    # center tick, rounded to a number of decimals chosen from the magnitude of the axis range
    xcenter = round((xmax + xmin) / 2, int(round(1 - np.log10(xmax - xmin + 1e-8))))
    s += draw_tick_mark(xcenter)
    # np.log10(xmax - xmin)

    # aim for roughly seven intervals across the axis, rounded the same way as the center
    tick_interval = round((xmax - xmin) / 7, int(round(1 - np.log10(xmax - xmin + 1e-8))))

    # tick_interval = (xmax - xmin) / 7
    # keep ticks away from the plot edges
    side_buffer = (xmax - xmin) / 14
    # ticks to the left of the center
    for i in range(1, 10):
        pos = xcenter - i * tick_interval
        if pos < xmin + side_buffer:
            break
        s += draw_tick_mark(pos)
    # ticks to the right of the center
    for i in range(1, 10):
        pos = xcenter + i * tick_interval
        if pos > xmax - side_buffer:
            break
        s += draw_tick_mark(pos)
    # the two special ticks are drawn last, with a white backing, so they sit on top of the regular ones
    s += draw_tick_mark(base_values, label="base value", backing=True)
    s += draw_tick_mark(
        fx, bold=True, label=f'f<tspan baseline-shift="sub" font-size="8px">{output_name}</tspan>(inputs)', backing=True
    )

    ### Positive value marks ###

    # scale the color components by 255 and cast to plain floats so the tuple formats cleanly inside "rgb{...}"
    red = (float(colors.red_rgb[0]) * 255, float(colors.red_rgb[1])* 255, float(colors.red_rgb[2])* 255)
    light_red = (255, 195, 213)

    # draw base red bar (spans the sum of all positive values, ending at fx)
    x = fx - values[values > 0].sum()
    w = 100 * values[values > 0].sum() / (xmax - xmin + 1e-8)
    s += f'<rect x="{xpos(x)}%" width="{w}%" y="40" height="18" style="fill:rgb{red}; stroke-width:0; stroke:rgb(0,0,0)" />'

    # draw underline marks and the text labels
    pos = fx
    last_pos = pos
    # positive entries, largest magnitude first; segments are laid out leftwards from fx
    inds = [i for i in np.argsort(-np.abs(values)) if values[i] > 0]
    for i, ind in enumerate(inds):
        v = values[ind]
        pos -= v

        # a line under the bar to animate
        s += f'<line x1="{xpos(pos)}%" x2="{xpos(last_pos)}%" y1="60" y2="60" id="_fb_{uuid}_ind_{ind}" style="stroke:rgb{red};stroke-width:2; opacity: 0"/>'

        # the value text shown below the bar (hidden until hovered)
        s += f'<text x="{(xpos(last_pos) + xpos(pos)) / 2}%" y="71" font-size="12px" id="_fs_{uuid}_ind_{ind}" fill="rgb{red}" style="opacity: 0" dominant-baseline="middle" text-anchor="middle">{values[ind].round(3)}</text>'

        # the text label cropped and centered (the nested svg limits the text to the segment's extent)
        s += f'<svg x="{xpos(pos)}%" y="40" height="20" width="{xpos(last_pos) - xpos(pos)}%">'
        s += ' <svg x="0" y="0" width="100%" height="100%">'
        s += f' <text x="50%" y="9" font-size="12px" fill="rgb(255,255,255)" dominant-baseline="middle" text-anchor="middle">{tokens[ind].strip()}</text>'
        s += " </svg>"
        s += "</svg>"

        last_pos = pos

    # draw the divider padding (which covers the text near the dividers)
    pos = fx
    # NOTE: last_pos is only read when i != 0, by which point it holds the previous segment's edge
    for i, ind in enumerate(inds):
        v = values[ind]
        pos -= v

        if i != 0:
            # four shifted chevrons on the right edge of this segment
            for j in range(4):
                s += f'<g transform="translate({2 * j - 8},0)">'
                s += f' <svg x="{xpos(last_pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 0 -9 l 6 18 L 0 25" fill="none" style="stroke:rgb{red};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        if i + 1 != len(inds):
            # four shifted chevrons on the left edge of this segment
            for j in range(4):
                s += f'<g transform="translate({2 * j - 0},0)">'
                s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 0 -9 l 6 18 L 0 25" fill="none" style="stroke:rgb{red};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        last_pos = pos

    # center padding
    s += f'<rect transform="translate(-8,0)" x="{xpos(fx)}%" y="40" width="8" height="18" style="fill:rgb{red}"/>'

    # cover up a notch at the end of the red bar
    pos = fx - values[values > 0].sum()
    s += '<g transform="translate(-11.5,0)">'
    s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
    s += ' <path d="M 10 -9 l 6 18 L 10 25 L 0 25 L 0 -9" fill="#ffffff" style="stroke:rgb(255,255,255);stroke-width:2" />'
    s += " </svg>"
    s += "</g>"

    # draw the light red divider lines and a rect to handle mouseover events
    pos = fx
    last_pos = pos
    for i, ind in enumerate(inds):
        v = values[ind]
        pos -= v

        # divider line
        if i + 1 != len(inds):
            s += '<g transform="translate(-1.5,0)">'
            s += f' <svg x="{xpos(last_pos)}%" y="40" height="18" overflow="visible" width="30">'
            s += f' <path d="M 0 -9 l 6 18 L 0 25" fill="none" style="stroke:rgb{light_red};stroke-width:2" />'
            s += " </svg>"
            s += "</g>"

        # mouse over rectangle (transparent; toggles the token underline, value text and underline bar)
        s += f'<rect x="{xpos(pos)}%" y="40" height="20" width="{xpos(last_pos) - xpos(pos)}%"'
        s += ' onmouseover="'
        s += f"document.getElementById('_tp_{uuid}_ind_{ind}').style.textDecoration = 'underline';"
        s += f"document.getElementById('_fs_{uuid}_ind_{ind}').style.opacity = 1;"
        s += f"document.getElementById('_fb_{uuid}_ind_{ind}').style.opacity = 1;"
        s += '"'
        s += ' onmouseout="'
        s += f"document.getElementById('_tp_{uuid}_ind_{ind}').style.textDecoration = 'none';"
        s += f"document.getElementById('_fs_{uuid}_ind_{ind}').style.opacity = 0;"
        s += f"document.getElementById('_fb_{uuid}_ind_{ind}').style.opacity = 0;"
        s += '" style="fill:rgb(0,0,0,0)" />'

        last_pos = pos

    ### Negative value marks ###

    blue = (float(colors.blue_rgb[0]) * 255, float(colors.blue_rgb[1]) * 255, float(colors.blue_rgb[2]) * 255)
    light_blue = (208, 230, 250)

    # draw base blue bar (starts at fx and spans the magnitude of all negative values)
    w = 100 * -values[values < 0].sum() / (xmax - xmin + 1e-8)
    s += f'<rect x="{xpos(fx)}%" width="{w}%" y="40" height="18" style="fill:rgb{blue}; stroke-width:0; stroke:rgb(0,0,0)" />'

    # draw underline marks and the text labels
    pos = fx
    last_pos = pos
    # negative entries, largest magnitude first; subtracting a negative v moves pos rightwards from fx
    inds = [i for i in np.argsort(-np.abs(values)) if values[i] < 0]
    for i, ind in enumerate(inds):
        v = values[ind]
        pos -= v

        # a line under the bar to animate
        s += f'<line x1="{xpos(last_pos)}%" x2="{xpos(pos)}%" y1="60" y2="60" id="_fb_{uuid}_ind_{ind}" style="stroke:rgb{blue};stroke-width:2; opacity: 0"/>'

        # the value text
        s += f'<text x="{(xpos(last_pos) + xpos(pos)) / 2}%" y="71" font-size="12px" fill="rgb{blue}" id="_fs_{uuid}_ind_{ind}" style="opacity: 0" dominant-baseline="middle" text-anchor="middle">{values[ind].round(3)}</text>'

        # the text label cropped and centered
        s += f'<svg x="{xpos(last_pos)}%" y="40" height="20" width="{xpos(pos) - xpos(last_pos)}%">'
        s += ' <svg x="0" y="0" width="100%" height="100%">'
        s += f' <text x="50%" y="9" font-size="12px" fill="rgb(255,255,255)" dominant-baseline="middle" text-anchor="middle">{tokens[ind].strip()}</text>'
        s += " </svg>"
        s += "</svg>"

        last_pos = pos

    # draw the divider padding (which covers the text near the dividers)
    pos = fx
    # NOTE: as above, last_pos is only read when i != 0
    for i, ind in enumerate(inds):
        v = values[ind]
        pos -= v

        if i != 0:
            for j in range(4):
                s += f'<g transform="translate({-2 * j + 2},0)">'
                s += f' <svg x="{xpos(last_pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 8 -9 l -6 18 L 8 25" fill="none" style="stroke:rgb{blue};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        if i + 1 != len(inds):
            for j in range(4):
                s += f'<g transform="translate(-{2 * j + 8},0)">'
                s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 8 -9 l -6 18 L 8 25" fill="none" style="stroke:rgb{blue};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        last_pos = pos

    # center padding
    s += f'<rect transform="translate(0,0)" x="{xpos(fx)}%" y="40" width="8" height="18" style="fill:rgb{blue}"/>'

    # cover up a notch at the end of the blue bar
    pos = fx - values[values < 0].sum()
    s += '<g transform="translate(-6.0,0)">'
    s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
    s += ' <path d="M 8 -9 l -6 18 L 8 25 L 20 25 L 20 -9" fill="#ffffff" style="stroke:rgb(255,255,255);stroke-width:2" />'
    s += " </svg>"
    s += "</g>"

    # draw the light blue divider lines and a rect to handle mouseover events
    pos = fx
    last_pos = pos
    for i, ind in enumerate(inds):
        v = values[ind]
        pos -= v

        # divider line
        if i + 1 != len(inds):
            s += '<g transform="translate(-6.0,0)">'
            s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
            s += f' <path d="M 8 -9 l -6 18 L 8 25" fill="none" style="stroke:rgb{light_blue};stroke-width:2" />'
            s += " </svg>"
            s += "</g>"

        # mouse over rectangle (transparent; toggles the token underline, value text and underline bar)
        s += f'<rect x="{xpos(last_pos)}%" y="40" height="20" width="{xpos(pos) - xpos(last_pos)}%"'
        s += ' onmouseover="'
        s += f"document.getElementById('_tp_{uuid}_ind_{ind}').style.textDecoration = 'underline';"
        s += f"document.getElementById('_fs_{uuid}_ind_{ind}').style.opacity = 1;"
        s += f"document.getElementById('_fb_{uuid}_ind_{ind}').style.opacity = 1;"
        s += '"'
        s += ' onmouseout="'
        s += f"document.getElementById('_tp_{uuid}_ind_{ind}').style.textDecoration = 'none';"
        s += f"document.getElementById('_fs_{uuid}_ind_{ind}').style.opacity = 0;"
        s += f"document.getElementById('_fb_{uuid}_ind_{ind}').style.opacity = 0;"
        s += '" style="fill:rgb(0,0,0,0)" />'

        last_pos = pos

    s += "</svg>"

    return s
747
+
748
+
749
def text_old(shap_values, tokens, partition_tree=None, num_starting_labels=0, grouping_threshold=1, separator=""):
    """Plot an explanation of a string of text using coloring and interactive labels.

    The output is interactive HTML: clicking on any token toggles the display of the
    SHAP value assigned to that token.

    Parameters
    ----------
    shap_values : numpy.ndarray
        Either one attribution per token, or (hierarchical input) one attribution per
        node of ``partition_tree``: the first ``len(tokens)`` entries belong to the
        tokens and entry ``len(tokens) + i`` belongs to row ``i`` of ``partition_tree``.
    tokens : sequence of str
        The tokens of the explained text.
    partition_tree : numpy.ndarray, optional
        Required when ``len(shap_values) != len(tokens)``. Columns 0 and 1 of row ``i``
        are used as the indices of the left and right child of internal node ``i``.
    num_starting_labels : int
        How many tokens (those with the largest absolute value) start with their
        value label visible.
    grouping_threshold : float
        An internal node is rendered as a single merged token when its per-token
        interaction value exceeds ``grouping_threshold`` times the larger of its two
        children's maximum values.
    separator : str
        String placed between tokens when they are merged into one group.

    Returns
    -------
    Whatever ``_ipython_display_html`` returns for the generated markup.

    Raises
    ------
    ValueError
        If the number of attributions differs from the number of tokens and no
        ``partition_tree`` was given.

    Warns
    -----
    FutureWarning
        Always; this function is scheduled for removal.
    """
    warnings.warn(
        "This function is not used within the shap library and will therefore be removed in an upcoming release. "
        "If you rely on this function, please open an issue: https://github.com/shap/shap/issues.",
        FutureWarning,
    )

    # See if we got hierarchical input data. If we did then we need to reprocess the
    # shap_values and tokens to get the groups we want to display
    M = len(tokens)
    if len(shap_values) != M:
        # make sure we were given a partition tree
        if partition_tree is None:
            raise ValueError(
                "The length of the attribution values must match the number of "
                "tokens if partition_tree is None! When passing hierarchical "
                "attributions the partition_tree is also required."
            )

        # compute the groups, lower_values, and max_values
        groups = [[i] for i in range(M)]
        lower_values = np.zeros(len(shap_values))
        lower_values[:M] = shap_values[:M]
        max_values = np.zeros(len(shap_values))
        max_values[:M] = np.abs(shap_values[:M])
        for i in range(partition_tree.shape[0]):
            li = partition_tree[i, 0]
            ri = partition_tree[i, 1]
            groups.append(groups[li] + groups[ri])
            lower_values[M + i] = lower_values[li] + lower_values[ri] + shap_values[M + i]
            max_values[i + M] = max(abs(shap_values[M + i]) / len(groups[M + i]), max_values[li], max_values[ri])

        # compute the upper_values
        upper_values = np.zeros(len(shap_values))

        def lower_credit(upper_values, partition_tree, i, value=0):
            """Push credit from node ``i`` down the tree, splitting it evenly between children."""
            if i < M:
                upper_values[i] = value
                return
            li = partition_tree[i - M, 0]
            ri = partition_tree[i - M, 1]
            upper_values[i] = value
            value += shap_values[i]

            lower_credit(upper_values, partition_tree, li, value * 0.5)
            lower_credit(upper_values, partition_tree, ri, value * 0.5)

        # start at the last node, which is treated as the root of the tree
        lower_credit(upper_values, partition_tree, len(shap_values) - 1)

        # the group_values comes from the dividends above them and below them
        group_values = lower_values + upper_values

        # merge all the tokens in groups dominated by interaction effects (since we don't want to hide those)
        new_tokens = []
        new_shap_values = []
        group_sizes = []

        def merge_tokens(new_tokens, new_values, group_sizes, i):
            """Collect the tokens/values to display below node ``i``, merging strongly interacting groups."""
            # return at the leaves
            if i < M and i >= 0:
                new_tokens.append(tokens[i])
                new_values.append(group_values[i])
                group_sizes.append(1)
            else:
                # compute the dividend at internal nodes
                li = partition_tree[i - M, 0]
                ri = partition_tree[i - M, 1]
                dv = abs(shap_values[i]) / len(groups[i])

                # if the interaction level is too high then just treat this whole group as one token
                if dv > grouping_threshold * max(max_values[li], max_values[ri]):
                    new_tokens.append(
                        separator.join([tokens[g] for g in groups[li]])
                        + separator
                        + separator.join([tokens[g] for g in groups[ri]])
                    )
                    new_values.append(group_values[i] / len(groups[i]))
                    group_sizes.append(len(groups[i]))
                # if interaction level is not too high we recurse
                else:
                    merge_tokens(new_tokens, new_values, group_sizes, li)
                    merge_tokens(new_tokens, new_values, group_sizes, ri)

        merge_tokens(new_tokens, new_shap_values, group_sizes, len(group_values) - 1)

        # replace the incoming parameters with the grouped versions
        tokens = np.array(new_tokens)
        shap_values = np.array(new_shap_values)
        group_sizes = np.array(group_sizes)
        M = len(tokens)
    else:
        group_sizes = np.ones(M)

    # indices of the tokens whose labels start out visible (a set gives O(1) membership tests)
    top_inds = set(np.argsort(-np.abs(shap_values))[:num_starting_labels].tolist())

    # Colors are scaled by the largest absolute attribution. When every attribution is
    # zero that scale would be 0 and the division below would yield NaN colors, so fall
    # back to 1, which maps every token to the neutral midpoint of the color map.
    scale = max(abs(shap_values.max()), abs(shap_values.min()))
    if scale == 0:
        scale = 1

    # build out the HTML output one token at a time
    pieces = []
    for i in range(M):
        scaled_value = 0.5 + 0.5 * shap_values[i] / scale
        color = colors.red_transparent_blue(scaled_value)
        color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))

        # display the labels for the most important words
        if i in top_inds:
            label_display = "block"
            wrapper_display = "inline-block"
        else:
            label_display = "none"
            wrapper_display = "inline"

        # create the value_label string (merged groups show "total / group size")
        if group_sizes[i] == 1:
            value_label = str(shap_values[i].round(3))
        else:
            value_label = str((shap_values[i] * group_sizes[i]).round(3)) + " / " + str(group_sizes[i])

        token_text = tokens[i].replace("<", "&lt;").replace(">", "&gt;").replace(" ##", "")

        # the HTML for this token; note the whitespace separating the style and onclick
        # attributes, which HTML attribute syntax requires
        pieces.append(
            "<div style='display: "
            + wrapper_display
            + "; text-align: center;'>"
            + "<div style='display: "
            + label_display
            + "; color: #999; padding-top: 0px; font-size: 12px;'>"
            + value_label
            + "</div>"
            + "<div "
            + "style='display: inline; background: rgba"
            + str(color)
            + "; border-radius: 3px; padding: 0px' "
            + "onclick=\"if (this.previousSibling.style.display == 'none') {"
            + "this.previousSibling.style.display = 'block';"
            + "this.parentNode.style.display = 'inline-block';"
            + "} else {"
            + "this.previousSibling.style.display = 'none';"
            + "this.parentNode.style.display = 'inline';"
            + "}"
            + '"'
            + ">"
            + token_text
            + "</div>"
            + "</div>"
        )

    return _ipython_display_html("".join(pieces))
899
+
900
+
901
def text_to_text(shap_values):
    """Display a text-to-text explanation with a selector between two visualizations.

    The markup produced by ``saliency_plot`` and ``heatmap`` for ``shap_values`` is
    embedded in one container together with a drop-down that switches between them
    (the heatmap is shown initially), and the result is handed to
    ``_ipython_display_html``.
    """
    # random identifier mixed into element ids and the JS function name so that
    # several instances on one page do not collide
    uuid = "".join(random.choices(string.ascii_lowercase, k=20))

    saliency_plot_markup = saliency_plot(shap_values)
    heatmap_markup = heatmap(shap_values)

    # script that hides every visualization and then reveals the selected one
    selector_script = f"""
    <script>
        function selectVizType_{uuid}(selectObject) {{

          /* Hide all viz */

            var elements = document.getElementsByClassName("{uuid}_viz_container")
          for (var i = 0; i < elements.length; i++){{
              elements[i].style.display = 'none';
          }}

          var value = selectObject.value;
          if ( value === "saliency-plot" ){{
              document.getElementById('{uuid}_saliency_plot_container').style.display = "block";
          }}
          else if ( value === "heatmap" ) {{
              document.getElementById('{uuid}_heatmap_container').style.display = "block";
          }}
        }}
    </script>
    """

    # header with the drop-down followed by one container per visualization
    page_markup = f"""
    <html>
    <div id="{uuid}_viz_container">
      <div id="{uuid}_viz_header" style="padding:15px;border-style:solid;margin:5px;font-family:sans-serif;font-weight:bold;">
        Visualization Type:
        <select name="viz_type" id="{uuid}_viz_type" onchange="selectVizType_{uuid}(this)">
          <option value="heatmap" selected="selected">Input/Output - Heatmap</option>
          <option value="saliency-plot">Saliency Plot</option>
        </select>
      </div>
      <div id="{uuid}_content" style="padding:15px;border-style:solid;margin:5px;">
          <div id = "{uuid}_saliency_plot_container" class="{uuid}_viz_container" style="display:none">
              {saliency_plot_markup}
          </div>

          <div id = "{uuid}_heatmap_container" class="{uuid}_viz_container">
              {heatmap_markup}
          </div>
      </div>
    </div>
    </html>
    """

    _ipython_display_html(selector_script + page_markup)
954
+
955
+
956
def saliency_plot(shap_values):
    """Build the HTML markup of a saliency table for a text-to-text explanation.

    The table has one column per (grouped) input token and one row per output name.
    Each cell shows the attribution of that input group to that output, rounded to
    three decimals, on a background color taken from ``colors.red_transparent_blue``.

    Parameters
    ----------
    shap_values
        Explanation-like object. This function reads ``data``, ``values`` (indexed as
        a 2-D matrix whose rows are summed per input group and whose columns are
        paired with ``output_names``), ``output_names`` and, through
        ``unpack_shap_explanation_contents``, its hierarchical values / clustering.

    Returns
    -------
    str
        HTML markup of the saliency plot.
    """
    uuid = "".join(random.choices(string.ascii_lowercase, k=20))

    unpacked_values, clustering = unpack_shap_explanation_contents(shap_values)
    # only the tokens and the group sizes of the first output are needed here
    tokens, _, group_sizes, _, _ = process_shap_values(
        shap_values.data, unpacked_values[:, 0], 1, "", clustering, True
    )

    def compress_shap_matrix(shap_matrix, group_sizes):
        """Sum consecutive rows of ``shap_matrix`` so that each group becomes one row."""
        compressed_matrix = np.zeros((group_sizes.shape[0], shap_matrix.shape[1]))
        start = 0
        for index, size in enumerate(group_sizes):
            compressed_matrix[index, :] = np.sum(shap_matrix[start : start + size, :], axis=0)
            start += size
        return compressed_matrix

    compressed_shap_matrix = compress_shap_matrix(shap_values.values, group_sizes)

    def get_colors(matrix):
        """Return a nested list of CSS ``rgba(...)`` strings, one per matrix entry."""
        cmax = max(abs(matrix.min()), abs(matrix.max()))
        if cmax == 0:
            # an all-zero matrix would otherwise divide by zero and produce NaN colors;
            # with a scale of 1 every cell maps to the neutral midpoint of the color map
            cmax = 1
        input_colors = []
        for row_index in range(matrix.shape[0]):
            input_colors_row = []
            for col_index in range(matrix.shape[1]):
                scaled_value = 0.5 + 0.5 * matrix[row_index, col_index] / cmax
                color = colors.red_transparent_blue(scaled_value)
                color = "rgba" + str((float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])))
                input_colors_row.append(color)
            input_colors.append(input_colors_row)
        return input_colors

    def clean_token(text):
        """Escape angle brackets and strip tokenizer word-piece markers for display."""
        return text.replace("<", "&lt;").replace(">", "&gt;").replace(" ##", "").replace("▁", "").replace("Ġ", "")

    model_output = shap_values.output_names
    input_colors = get_colors(compressed_shap_matrix)
    num_inputs, num_outputs = compressed_shap_matrix.shape

    parts = ['<table border = "1" cellpadding = "5" cellspacing = "5" style="overflow-x:scroll;display:block;">']

    # add top row containing input tokens (first cell is the empty corner)
    parts.append("<tr>")
    parts.append("<th></th>")
    for j in range(num_inputs):
        parts.append("<th>" + clean_token(tokens[j]) + "</th>")
    parts.append("</tr>")

    # one row per output: its name followed by the colored attribution cells
    for row_index in range(num_outputs):
        parts.append("<tr>")
        parts.append("<th>" + clean_token(model_output[row_index]) + "</th>")
        for col_index in range(num_inputs):
            parts.append(
                '<th style="background:'
                + input_colors[col_index][row_index]
                + '">'
                + str(round(compressed_shap_matrix[col_index][row_index], 3))
                + "</th>"
            )
        parts.append("</tr>")

    parts.append("</table>")
    out = "".join(parts)

    # The caption matches the table built above: input tokens run along the header
    # row (x-axis) and output tokens label the rows (y-axis).
    saliency_plot_html = f"""
        <div id="{uuid}_saliency_plot" class="{uuid}_viz_content">
            <div style="margin:5px;font-family:sans-serif;font-weight:bold;">
                <span style="font-size: 20px;"> Saliency Plot </span>
                <br>
                x-axis: Input Text
                <br>
                y-axis: Output Text
            </div>
            {out}
        </div>
    """
    return saliency_plot_html
1046
+
1047
+
1048
def heatmap(shap_values):
    """Build the HTML/JavaScript markup of the interactive input/output heatmap.

    Input tokens and output tokens are rendered side by side. Hovering over or
    clicking a token on one side colors the tokens on the other side with the
    corresponding attribution values; the colors and labels are precomputed here
    and embedded as JSON lookup tables for the generated JavaScript.

    Parameters
    ----------
    shap_values
        Explanation-like object. This function reads ``data``, ``values``,
        ``output_names`` and ``clustering`` (and hierarchical values, if present,
        through ``unpack_shap_explanation_contents``).

    Returns
    -------
    str
        The heatmap HTML, followed by its JavaScript handlers, followed by the
        ``<script>`` block defining the precomputed lookup tables.
    """
    # keys of the nested dictionaries that describe the token tree
    TREE_NODE_KEY_TOKENS = "tokens"
    TREE_NODE_KEY_CHILDREN = "children"

    # random identifier mixed into element ids and JS names to avoid collisions
    uuid = "".join(random.choices(string.ascii_lowercase, k=20))

    def get_color(shap_value, cmax):
        """Map ``shap_value`` (scaled by ``cmax``) to an ``(r, g, b, a)`` tuple."""
        scaled_value = 0.5 + 0.5 * shap_value / cmax
        color = colors.red_transparent_blue(scaled_value)
        color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))
        return color

    def clean_token(text):
        """Escape angle brackets and strip tokenizer word-piece markers for display."""
        return text.replace("<", "&lt;").replace(">", "&gt;").replace(" ##", "").replace("▁", "").replace("Ġ", "")

    def process_text_to_text_shap_values(shap_values):
        """Run ``process_shap_values`` once per output and track the largest magnitude seen."""
        processed_values = []

        unpacked_values, clustering = unpack_shap_explanation_contents(shap_values)
        max_val = 0

        for index in range(len(shap_values.output_names)):
            tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids = process_shap_values(
                shap_values.data, unpacked_values[:, index], 1, "", clustering, True
            )
            processed_values.append(
                {
                    "tokens": tokens,
                    "values": values,
                    "group_sizes": group_sizes,
                    "token_id_to_node_id_mapping": token_id_to_node_id_mapping,
                    "collapsed_node_ids": collapsed_node_ids,
                }
            )
            # use the absolute value so that a large negative group value also
            # widens the color scale
            max_val = max(max_val, np.max(np.abs(values)))

        return processed_values, max_val

    # unpack input tokens and output tokens
    model_input = shap_values.data
    model_output = shap_values.output_names

    processed_values, max_val = process_text_to_text_shap_values(shap_values)

    # generate dictionary containing precomputed background colors and shap values which are addressable by html token ids
    colors_dict = {}
    shap_values_dict = {}
    token_id_to_node_id_mapping = {}
    cmax = max(abs(shap_values.values.min()), abs(shap_values.values.max()), max_val)
    if cmax == 0:
        # every attribution is zero: avoid dividing by zero in get_color
        cmax = 1

    # input token -> output token color and label value mapping

    for row_index in range(len(model_input)):
        color_values = {}
        shap_values_list = {}

        for col_index in range(len(model_output)):
            color_values[uuid + "_output_flat_token_" + str(col_index)] = "rgba" + str(
                get_color(shap_values.values[row_index][col_index], cmax)
            )
            shap_values_list[uuid + "_output_flat_value_label_" + str(col_index)] = round(
                shap_values.values[row_index][col_index], 3
            )

        colors_dict[f"{uuid}_input_node_{row_index}_content"] = color_values
        shap_values_dict[f"{uuid}_input_node_{row_index}_content"] = shap_values_list

    # output token -> input token color and label value mapping

    for col_index in range(len(model_output)):
        processed = processed_values[col_index]
        color_values = {}
        shap_values_list = {}

        for row_index in range(processed["collapsed_node_ids"].shape[0]):
            node_prefix = uuid + "_input_node_" + str(processed["collapsed_node_ids"][row_index])
            color_values[node_prefix + "_content"] = "rgba" + str(get_color(processed["values"][row_index], cmax))

            # grouped tokens are labelled "value/group size"
            shap_label_value_str = str(round(processed["values"][row_index], 3))
            if processed["group_sizes"][row_index] > 1:
                shap_label_value_str += "/" + str(processed["group_sizes"][row_index])

            shap_values_list[node_prefix + "_label"] = shap_label_value_str

        colors_dict[uuid + "_output_flat_token_" + str(col_index)] = color_values
        shap_values_dict[uuid + "_output_flat_token_" + str(col_index)] = shap_values_list

        token_id_to_node_id_mapping_dict = {}

        for index, node_id in enumerate(processed["token_id_to_node_id_mapping"].tolist()):
            token_id_to_node_id_mapping_dict[f"{uuid}_input_node_{index}_content"] = (
                f"{uuid}_input_node_{int(node_id)}_content"
            )

        token_id_to_node_id_mapping[uuid + "_output_flat_token_" + str(col_index)] = token_id_to_node_id_mapping_dict

    # convert python dictionary into json to be inserted into the runtime javascript environment
    colors_json = json.dumps(colors_dict)
    shap_values_json = json.dumps(shap_values_dict)
    token_id_to_node_id_mapping_json = json.dumps(token_id_to_node_id_mapping)

    javascript_values = (
        "<script> "
        f"colors_{uuid} = {colors_json}\n"
        f" shap_values_{uuid} = {shap_values_json}\n"
        f" token_id_to_node_id_mapping_{uuid} = {token_id_to_node_id_mapping_json}\n"
        "</script> \n "
    )

    def generate_tree(shap_values):
        """Nest the input tokens into a dictionary tree following ``shap_values.clustering``."""
        num_tokens = len(shap_values.data)
        token_list = {}

        # every token starts out as a leaf keyed by its index
        for index in range(num_tokens):
            node_content = {}
            node_content[TREE_NODE_KEY_TOKENS] = shap_values.data[index]
            node_content[TREE_NODE_KEY_CHILDREN] = {}
            token_list[str(index)] = node_content

        # each clustering row merges two existing nodes into a new node that is
        # keyed by the next free index
        counter = num_tokens
        for pair in shap_values.clustering:
            first_node = str(int(pair[0]))
            second_node = str(int(pair[1]))

            new_node_content = {}
            new_node_content[TREE_NODE_KEY_CHILDREN] = {
                first_node: token_list[first_node],
                second_node: token_list[second_node],
            }

            token_list[str(counter)] = new_node_content
            counter += 1

            del token_list[first_node]
            del token_list[second_node]

        return token_list

    tree = generate_tree(shap_values)

    # generates the input token html elements
    # each element contains the label value (initially hidden) and the token text

    input_text_html = ""

    def populate_input_tree(input_index, token_list_subtree, input_text_html):
        """Append the markup of node ``input_index`` (and its descendants) and return the result."""
        content = token_list_subtree[input_index]
        children = content[TREE_NODE_KEY_CHILDREN]

        input_text_html += (
            f'<div id="{uuid}_input_node_{input_index}_container" style="display:inline;text-align:center">'
        )

        input_text_html += (
            f'<div id="{uuid}_input_node_{input_index}_label" style="display:none; padding-top: 0px; font-size:12px;">'
        )

        input_text_html += "</div>"

        if children:
            input_text_html += f'<div id="{uuid}_input_node_{input_index}_content" style="display:inline;">'
            for child_index in children:
                input_text_html = populate_input_tree(child_index, children, input_text_html)
            input_text_html += "</div>"
        else:
            # leaf token; attributes are separated by whitespace as HTML requires
            input_text_html += (
                f'<div id="{uuid}_input_node_{input_index}_content" '
                "style='display: inline; background:transparent; border-radius: 3px; padding: 0px;cursor: default;cursor: pointer;' "
                f'onmouseover="onMouseHoverFlat_{uuid}(this.id)" '
                f'onmouseout="onMouseOutFlat_{uuid}(this.id)" '
                f'onclick="onMouseClickFlat_{uuid}(this.id)" '
                ">"
            )
            input_text_html += clean_token(content[TREE_NODE_KEY_TOKENS])
            input_text_html += "</div>"

        input_text_html += "</div>"

        return input_text_html

    # render starting from the first remaining top-level node of the tree
    input_text_html = populate_input_tree(next(iter(tree)), tree, input_text_html)

    # generates the output token html elements
    output_text_html = ""

    for i in range(len(model_output)):
        output_text_html += (
            "<div style='display:inline; text-align:center;'>"
            f"<div id='{uuid}_output_flat_value_label_{i}' "
            "style='display:none;color: #999; padding-top: 0px; font-size:12px;'>"
            "</div>"
            f"<div id='{uuid}_output_flat_token_{i}' "
            "style='display: inline; background:transparent; border-radius: 3px; padding: 0px;cursor: default;cursor: pointer;' "
            f'onmouseover="onMouseHoverFlat_{uuid}(this.id)" '
            f'onmouseout="onMouseOutFlat_{uuid}(this.id)" '
            f'onclick="onMouseClickFlat_{uuid}(this.id)" '
            ">"
            + clean_token(model_output[i])
            + " </div>"
            + "</div>"
        )

    heatmap_html = f"""
        <div id="{uuid}_heatmap" class="{uuid}_viz_content">
          <div id="{uuid}_heatmap_header" style="padding:15px;margin:5px;font-family:sans-serif;font-weight:bold;">
            <div style="display:inline">
              <span style="font-size: 20px;"> Input/Output - Heatmap </span>
            </div>
            <div style="display:inline;float:right">
              Layout :
              <select name="alignment" id="{uuid}_alignment" onchange="selectAlignment_{uuid}(this)">
                <option value="left-right" selected="selected">Left/Right</option>
                <option value="top-bottom">Top/Bottom</option>
              </select>
            </div>
          </div>
          <div id="{uuid}_heatmap_content" style="display:flex;">
            <div id="{uuid}_input_container" style="padding:15px;border-style:solid;margin:5px;flex:1;">
              <div id="{uuid}_input_header" style="margin:5px;font-weight:bold;font-family:sans-serif;margin-bottom:10px">
                Input Text
              </div>
              <div id="{uuid}_input_content" style="margin:5px;font-family:sans-serif;">
                  {input_text_html}
              </div>
            </div>
            <div id="{uuid}_output_container" style="padding:15px;border-style:solid;margin:5px;flex:1;">
              <div id="{uuid}_output_header" style="margin:5px;font-weight:bold;font-family:sans-serif;margin-bottom:10px">
                Output Text
              </div>
              <div id="{uuid}_output_content" style="margin:5px;font-family:sans-serif;">
                  {output_text_html}
              </div>
            </div>
          </div>
        </div>
    """

    # label_content_id is declared with `var` in each handler so that it stays
    # local instead of becoming a global shared by every plot on the page
    heatmap_javascript = f"""
        <script>
            function selectAlignment_{uuid}(selectObject) {{
                var value = selectObject.value;
                if ( value === "left-right" ){{
                  document.getElementById('{uuid}_heatmap_content').style.display = "flex";
                }}
                else if ( value === "top-bottom" ) {{
                  document.getElementById('{uuid}_heatmap_content').style.display = "inline";
                }}
            }}

            var {uuid}_heatmap_flat_state = null;

            function onMouseHoverFlat_{uuid}(id) {{
                if ({uuid}_heatmap_flat_state === null) {{
                    setBackgroundColors_{uuid}(id);
                    document.getElementById(id).style.backgroundColor = "grey";
                }}

                if (getIdSide_{uuid}(id) === 'input' && getIdSide_{uuid}({uuid}_heatmap_flat_state) === 'output'){{

                    var label_content_id = token_id_to_node_id_mapping_{uuid}[{uuid}_heatmap_flat_state][id];

                    if (document.getElementById(label_content_id).previousElementSibling.style.display == 'none'){{
                        document.getElementById(label_content_id).style.textShadow = "0px 0px 1px #000000";
                    }}

                }}

            }}

            function onMouseOutFlat_{uuid}(id) {{
                if ({uuid}_heatmap_flat_state === null) {{
                    cleanValuesAndColors_{uuid}(id);
                    document.getElementById(id).style.backgroundColor = "transparent";
                }}

                if (getIdSide_{uuid}(id) === 'input' && getIdSide_{uuid}({uuid}_heatmap_flat_state) === 'output'){{

                    var label_content_id = token_id_to_node_id_mapping_{uuid}[{uuid}_heatmap_flat_state][id];

                    if (document.getElementById(label_content_id).previousElementSibling.style.display == 'none'){{
                        document.getElementById(label_content_id).style.textShadow = "inherit";
                    }}

                }}

            }}

            function onMouseClickFlat_{uuid}(id) {{
                if ({uuid}_heatmap_flat_state === id) {{

                    // If the clicked token was already selected

                    document.getElementById(id).style.backgroundColor = "transparent";
                    cleanValuesAndColors_{uuid}(id);
                    {uuid}_heatmap_flat_state = null;
                }}
                else {{
                    if ({uuid}_heatmap_flat_state === null) {{

                        // No token previously selected, new token clicked on

                        cleanValuesAndColors_{uuid}(id)
                        {uuid}_heatmap_flat_state = id;
                        document.getElementById(id).style.backgroundColor = "grey";
                        setLabelValues_{uuid}(id);
                        setBackgroundColors_{uuid}(id);
                    }}
                    else {{
                        if (getIdSide_{uuid}({uuid}_heatmap_flat_state) === getIdSide_{uuid}(id)) {{

                            // User clicked a token on the same side as the currently selected token

                            cleanValuesAndColors_{uuid}({uuid}_heatmap_flat_state)
                            document.getElementById({uuid}_heatmap_flat_state).style.backgroundColor = "transparent";
                            {uuid}_heatmap_flat_state = id;
                            document.getElementById(id).style.backgroundColor = "grey";
                            setLabelValues_{uuid}(id);
                            setBackgroundColors_{uuid}(id);
                        }}
                        else{{

                            if (getIdSide_{uuid}(id) === 'input') {{
                                var label_content_id = token_id_to_node_id_mapping_{uuid}[{uuid}_heatmap_flat_state][id];

                                if (document.getElementById(label_content_id).previousElementSibling.style.display == 'none') {{
                                    document.getElementById(label_content_id).previousElementSibling.style.display = 'block';
                                    document.getElementById(label_content_id).parentNode.style.display = 'inline-block';
                                    document.getElementById(label_content_id).style.textShadow = "0px 0px 1px #000000";
                                }}
                                else {{
                                    document.getElementById(label_content_id).previousElementSibling.style.display = 'none';
                                    document.getElementById(label_content_id).parentNode.style.display = 'inline';
                                    document.getElementById(label_content_id).style.textShadow = "inherit";
                                }}

                            }}
                            else {{
                                if (document.getElementById(id).previousElementSibling.style.display == 'none') {{
                                    document.getElementById(id).previousElementSibling.style.display = 'block';
                                    document.getElementById(id).parentNode.style.display = 'inline-block';
                                }}
                                else {{
                                    document.getElementById(id).previousElementSibling.style.display = 'none';
                                    document.getElementById(id).parentNode.style.display = 'inline';
                                }}
                            }}

                        }}
                    }}

                }}
            }}

            function setLabelValues_{uuid}(id) {{
                for(const token in shap_values_{uuid}[id]){{
                    document.getElementById(token).innerHTML = shap_values_{uuid}[id][token];
                    document.getElementById(token).nextElementSibling.title = 'SHAP Value : ' + shap_values_{uuid}[id][token];
                }}
            }}

            function setBackgroundColors_{uuid}(id) {{
                for(const token in colors_{uuid}[id]){{
                    document.getElementById(token).style.backgroundColor = colors_{uuid}[id][token];
                }}
            }}

            function cleanValuesAndColors_{uuid}(id) {{
                for(const token in shap_values_{uuid}[id]){{
                    document.getElementById(token).innerHTML = "";
                    document.getElementById(token).nextElementSibling.title = "";
                }}
                for(const token in colors_{uuid}[id]){{
                    document.getElementById(token).style.backgroundColor = "transparent";
                    document.getElementById(token).previousElementSibling.style.display = 'none';
                    document.getElementById(token).parentNode.style.display = 'inline';
                    document.getElementById(token).style.textShadow = "inherit";
                }}
            }}

            function getIdSide_{uuid}(id) {{
                if (id === null) {{
                    return 'null'
                }}
                return id.split("_")[1];
            }}
        </script>
    """

    return heatmap_html + heatmap_javascript + javascript_values
1449
+
1450
+
1451
def unpack_shap_explanation_contents(shap_values):
    """Return the attribution values (as a NumPy array) and clustering of ``shap_values``.

    ``hierarchical_values`` is used when the attribute exists and is not ``None``;
    otherwise ``values`` is used. The clustering is ``None`` when the object has no
    ``clustering`` attribute.
    """
    hierarchical = getattr(shap_values, "hierarchical_values", None)
    raw_values = shap_values.values if hierarchical is None else hierarchical
    return np.array(raw_values), getattr(shap_values, "clustering", None)
1458
+
1459
+
1460
def _ipython_display_html(data):
    """Display ``data`` as HTML through IPython.

    Raises
    ------
    ImportError
        If the module-level ``have_ipython`` flag is false.
    """
    if have_ipython:
        return ipython_display(HTML(data))

    msg = "IPython is required for this function but is not installed. Fix this with `pip install ipython`."
    raise ImportError(msg)