Skip to content

Commit

Permalink
fix shape check to support other types of wrapped datasets such as BatchDataset (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Apr 11, 2022
1 parent 75fdf32 commit 1f2d978
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 11 deletions.
36 changes: 27 additions & 9 deletions python/interpret_community/common/blackbox_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,21 @@ def __init__(self, model, is_function=False, model_task=ModelTask.Unknown, **kwa
if model_task == ModelTask.Classification:
raise Exception("No predict_proba method on model which has model_task='classifier'")

def _get_ys_dict(self, evaluation_examples, transformations=None, allow_all_transformations=False):
def _get_ys_dict(self, evaluation_examples, transformations=None,
allow_all_transformations=False):
"""Get the predicted ys to be incorporated into a kwargs dictionary.
:param evaluation_examples: The same ones we usually work with, must be able to be passed into the
model or function.
:type evaluation_examples: numpy.ndarray or pandas.DataFrame or scipy.sparse.csr_matrix
:param evaluation_examples: The same ones we usually work with,
must be able to be passed into the model or function.
:type evaluation_examples: DatasetWrapper or numpy.ndarray or
pandas.DataFrame or scipy.sparse.csr_matrix
:param transformations: See documentation on any explainer.
:type transformations: sklearn.compose.ColumnTransformer or list[tuple]
:param allow_all_transformations: Allow many to many and many to one transformations
:param allow_all_transformations: Allow many to many and many to
one transformations
:type allow_all_transformations: bool
:return: The dictionary with none, one, or both of predicted ys and predicted proba ys for eval
examples.
:return: The dictionary with none, one, or both of predicted ys
and predicted proba ys for eval examples.
:rtype: dict
"""
if transformations is not None:
Expand All @@ -123,7 +127,7 @@ def _get_ys_dict(self, evaluation_examples, transformations=None, allow_all_tran
)
if isinstance(evaluation_examples, DatasetWrapper):
evaluation_examples = evaluation_examples.original_dataset_with_type
if len(evaluation_examples.shape) == 1:
if hasattr(evaluation_examples, 'shape') and len(evaluation_examples.shape) == 1:
evaluation_examples = evaluation_examples.reshape(1, -1)
ys_dict = {}
if self.model is not None:
Expand Down Expand Up @@ -238,7 +242,21 @@ def wrapper(data):
return wrapper

def _prepare_function_and_summary(self, function, original_data_ref,
current_index_list, explain_subset=None, **kwargs):
current_index_list,
explain_subset=None, **kwargs):
"""Prepare the initialization dataset and the wrapper function for predictions.
:param function: The prediction function.
:type function: function
:param original_data_ref: The original data reference.
:type original_data_ref: list
:param current_index_list: Pointer to the current row to be evaluated.
:type current_index_list: list
:param explain_subset: The subset of feature indexes to explain.
:type explain_subset: list
:return: The prepared function and the summary dataset.
:rtype: function, ndarray
"""
if explain_subset:
# Note: need to take subset before compute summary
self.initialization_examples.take_subset(explain_subset)
Expand Down
32 changes: 31 additions & 1 deletion tests/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

try:
from tensorflow import keras
from tensorflow.keras.layers import Activation, Dense, Dropout
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Activation, Dense, Dropout, concatenate
from tensorflow.keras.models import Sequential
from tensorflow.keras.wrappers.scikit_learn import (KerasClassifier,
KerasRegressor)
Expand Down Expand Up @@ -328,6 +329,35 @@ def create_pytorch_regressor(X, y):
return _train_pytorch_model(epochs, criterion, optimizer, net, torch_X, torch_y)


def create_tf_model(inp_ds, val_ds, feature_names):
    """Build, compile and fit a small Keras functional model for regression.

    The network takes one scalar float32 input per feature name, joins
    them with ``concatenate``, and passes the result through a 20-unit
    ReLU layer followed by a single-unit output layer.

    :param inp_ds: input data set used for fitting.
    :type inp_ds: BatchDataset
    :param val_ds: validation data set passed to ``fit``.
    :type val_ds: BatchDataset
    :param feature_names: list of feature names; each becomes a named input.
    :type feature_names: list
    :return: the fitted TF model.
    :rtype: tf.keras.Model
    """
    # One named input layer per feature, keyed by the feature name.
    input_layers = {}
    for column in list(feature_names):
        input_layers[column] = Input(name=column, shape=(1,), dtype='float32')

    # Join the per-feature inputs in feature order before the dense stack.
    ordered_inputs = [input_layers[column] for column in list(feature_names)]
    merged = concatenate(ordered_inputs)
    hidden = Dense(20, activation='relu', name='hidden1')(merged)
    prediction = Dense(1)(hidden)

    regressor = Model(inputs=input_layers, outputs=prediction)
    regressor.compile(optimizer='adam', loss='mse', metrics=['mae', 'mse'])
    regressor.fit(inp_ds, epochs=5, validation_data=val_ds)
    return regressor


def create_keras_classifier(X, y):
# create simple (dummy) Keras DNN model for binary classification
batch_size = 128
Expand Down
39 changes: 38 additions & 1 deletion tests/test_explain_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
create_sklearn_linear_regressor,
create_sklearn_random_forest_classifier,
create_sklearn_random_forest_regressor,
create_sklearn_svm_classifier,
create_sklearn_svm_classifier, create_tf_model,
create_xgboost_classifier,
wrap_classifier_without_proba)
from constants import DatasetConstants, owner_email_tools_and_ux
Expand All @@ -30,6 +30,7 @@
TreeExplainer)
from interpret_community.tabular_explainer import _get_uninitialized_explainers
from lightgbm import LGBMClassifier
from ml_wrappers import DatasetWrapper, wrap_model
from raw_explain.utils import _get_feature_map_from_indices_list
from scipy.sparse import csr_matrix
from sklearn.compose import ColumnTransformer
Expand All @@ -40,9 +41,16 @@
from sklearn.preprocessing import (FunctionTransformer, OneHotEncoder,
StandardScaler)

try:
import tensorflow as tf
except ImportError:
pass


test_logger = logging.getLogger(__name__)
test_logger.setLevel(logging.DEBUG)


DATA_SLICE = slice(10)


Expand Down Expand Up @@ -811,6 +819,35 @@ def test_explain_model_keras_classifier(self, iris, tabular_explainer):
assert len(local_importance_values) == num_classes, ('length of local importances does not match number '
'of classes')

def test_explain_model_batch_dataset(self, housing, tabular_explainer):
    """Explain a TF model using batched tf.data datasets as the data.

    Fits the model from ``create_tf_model`` on a batched training dataset,
    wraps the batched validation dataset in a ``DatasetWrapper``, and checks
    that the global importances have one entry per feature and the local
    importances have one entry per evaluated row.
    """
    feature_names = housing[DatasetConstants.FEATURES]
    train_labels = housing[DatasetConstants.Y_TRAIN]
    test_labels = housing[DatasetConstants.Y_TEST][DATA_SLICE]
    train_frame = pd.DataFrame(housing[DatasetConstants.X_TRAIN],
                               columns=list(feature_names))
    test_frame = pd.DataFrame(housing[DatasetConstants.X_TEST][DATA_SLICE],
                              columns=list(feature_names))

    # Each dataset element is a (dict of per-column values, label) pair,
    # grouped into batches of 32.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (dict(train_frame), train_labels)).batch(32)
    eval_ds = tf.data.Dataset.from_tensor_slices(
        (dict(test_frame), test_labels)).batch(32)

    tf_model = create_tf_model(train_ds, eval_ds, feature_names)
    wrapped_dataset = DatasetWrapper(eval_ds)
    wrapped_model = wrap_model(tf_model, wrapped_dataset,
                               model_task='regression')

    explainer = tabular_explainer(wrapped_model, train_ds)
    global_explanation = explainer.explain_global(wrapped_dataset)
    local_explanation = explainer.explain_local(wrapped_dataset)

    expected_rows, expected_feats = test_frame.shape[0], test_frame.shape[1]
    global_importances = global_explanation.global_importance_values
    assert len(global_importances) == expected_feats, (
        'length of global importances does not match number of features')
    local_importances = local_explanation.local_importance_values
    assert len(local_importances) == expected_rows, (
        'length of local importances does not match number of rows')

def verify_adult_overall_features(self, ranked_global_names, ranked_global_values):
# Verify order of features
test_logger.info(ranked_global_names)
Expand Down

0 comments on commit 1f2d978

Please sign in to comment.