{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example 3: Custom value functions\n",
    "#### These examples are from section 5.1 of the Faith-Shap paper: https://arxiv.org/abs/2203.00870"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c10cfd76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "\n",
    "import nshap"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f0b5f063",
   "metadata": {},
   "source": [
    "### Example 1\n",
    "##### A value function takes two arguments: A single data point x (a numpy.ndarray) and a Python list S with the indices of the coordinates that belong to the coalition.\n",
    "##### In this example, the value does not actually depend on the point x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "44978e93",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.4"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p = 0.1\n",
    "def v_func(x, S):\n",
    "    \"\"\"Value function from Example 1 in the Faith-Shap paper.\n",
    "\n",
    "    Args:\n",
    "        x: The data point. It is accepted for interface compatibility but never read.\n",
    "        S: List with the indices of the coordinates that belong to the coalition.\n",
    "\n",
    "    Returns:\n",
    "        0 if the coalition has at most one member, otherwise\n",
    "        len(S) - p * C(len(S), 2).\n",
    "\n",
    "    Note:\n",
    "        p is the module-level variable and is looked up at call time, so\n",
    "        rebinding p after this cell changes the values returned here.\n",
    "    \"\"\"\n",
    "    if len(S) <= 1:\n",
    "        return 0\n",
    "    return len(S) - p * math.comb(len(S), 2)\n",
    "\n",
    "\n",
    "v_func(np.zeros((1,11)), [1,2,3,4])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "86a0f974",
   "metadata": {},
   "source": [
    "##### Equipped with the value function, we can compute different kinds of interaction indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbbbbcf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order 2 is requested because the entries read below, (0,1) and (1,2),\n",
    "# are pairwise (second-order) terms.\n",
    "faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, 2)\n",
    "faith_shap[(0,1)], faith_shap[(1,2)]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "08657d28",
   "metadata": {},
   "source": [
    "##### We can replicate Table 1 in the Faith-Shap paper"
   ]
  },
  {
"cell_type": "code",
   "execution_count": 13,
   "id": "d4bcee91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 1 in the Faith-Shap paper:\n",
      "\n",
      "p=0.1 l=1\n",
      "\n",
      "Faith-Shap: 0.5000000000001867\n",
      "Shapley Taylor: 0.500000000000185\n",
      "Interaction Shapley: 0.4999999999999939\n",
      "Banzhaf Interaction: 0.5087890624999979\n",
      "Faith-Banzhaf: 0.5087890624999989\n",
      "\n",
      "p=0.1 l=2\n",
      "\n",
      "Faith-Shap: 0.9545454545463214 -0.09090909090923938\n",
      "Shapley Taylor: 0 0.10000000000002415\n",
      "Interaction Shapley: 0.4999999999999939 -5.759281940243e-16\n",
      "Banzhaf Interaction: 0.5087890624999979 -0.11367187500000073\n",
      "Faith-Banzhaf: 1.0771484374999503 -0.11367187499999737\n",
      "\n",
      "p=0.2 l=1\n",
      "\n",
      "Faith-Shap: 1.6486811915683575e-13\n",
      "Shapley Taylor: 1.7552626019323725e-13\n",
      "Interaction Shapley: 2.7755575615628914e-17\n",
      "Banzhaf Interaction: 0.0087890625\n",
      "Faith-Banzhaf: 0.008789062500008604\n",
      "\n",
      "p=0.2 l=2\n",
      "\n",
      "Faith-Shap: 0.9545454545460967 -0.19090909090925617\n",
      "Shapley Taylor: 0 -3.7941871866564725e-14\n",
      "Interaction Shapley: 2.7755575615628914e-17 -0.10000000000000021\n",
      "Banzhaf Interaction: 0.0087890625 -0.21367187500000198\n",
      "Faith-Banzhaf: 1.0771484374999574 -0.21367187499998952\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def make_v_func(p):\n",
    "    \"\"\"Build the Example 1 value function for a fixed value of p.\n",
    "\n",
    "    Args:\n",
    "        p: Coefficient multiplying the number of pairs in the coalition.\n",
    "\n",
    "    Returns:\n",
    "        A function value(x, S) that ignores the data point x and returns 0\n",
    "        when S has at most one member, otherwise len(S) - p * C(len(S), 2).\n",
    "    \"\"\"\n",
    "    def value(x, S):\n",
    "        if len(S) <= 1:\n",
    "            return 0\n",
    "        return len(S) - p * math.comb(len(S), 2)\n",
    "    return value\n",
    "\n",
    "\n",
    "# (printed label, nshap function) pairs, listed in the order the rows are printed\n",
    "index_functions = [\n",
    "    ('Faith-Shap: ', nshap.faith_shap),\n",
    "    ('Shapley Taylor: ', nshap.shapley_taylor),\n",
    "    ('Interaction Shapley: ', nshap.shapley_interaction_index),\n",
    "    ('Banzhaf Interaction: ', nshap.banzhaf_interaction_index),\n",
    "    ('Faith-Banzhaf: ', nshap.faith_banzhaf),\n",
    "]\n",
    "\n",
    "# the value function never reads the data point, so a single placeholder is reused\n",
    "data_point = np.zeros((1,11))\n",
    "\n",
    "print('Table 1 in the Faith-Shap paper:\\n')\n",
    "# The loop variables are deliberately not called p / l, and the value function is\n",
    "# not bound to the name v_func, so this cell leaves the Example 1 definitions intact.\n",
    "for p_value in [0.1, 0.2]:\n",
    "    for order in [1, 2]:\n",
    "        print(f'p={p_value} l={order}\\n')\n",
    "\n",
    "        # define the value function\n",
    "        value_function = make_v_func(p_value)\n",
    "\n",
    "        # compute and print each interaction index\n",
    "        for label, compute_index in index_functions:\n",
    "            index = compute_index(data_point, value_function, order)\n",
    "            if order == 1:\n",
    "                print(label, index[(0,)])\n",
    "            else:\n",
    "                # the pairwise term is only read when order 2 was requested\n",
    "                print(label, index[(0,)], index[(0,1)])\n",
    "        print('')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "62406f3c2942480da828869ab3f3f95d1c0177b3689d5bc770f3ddfd7b9b3df5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}