You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2021/03/10 01:14:51 UTC

[GitHub] [madlib] kaknikhil commented on pull request #558: DL: Fix num_class parsing from model architecture

kaknikhil commented on pull request #558:
URL: https://github.com/apache/madlib/pull/558#issuecomment-794693836


   The changes LGTM, but I would suggest adding a few unit tests to InputValidatorTestCase to make sure this is well tested, since this code also affects non-multi-IO code.
   
   ```
       def test_validate_class_values_last_layer_not_dense(self):
           # One Dense layer followed by trailing Activation layers, so the
           # final layer is not Dense. validate_class_values is expected to
           # accept this architecture for a single dependent variable
           # (no assertRaises: the test passes if nothing is raised).
           n_classes = 3
           arch = Sequential()
           arch.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
                           input_shape=(1, 1, 1), padding='same'))
           arch.add(Dense(n_classes))
           for act_name in ('relu', 'softmax'):
               arch.add(Activation(act_name))

           class_values = [range(n_classes)]
           self.subject.validate_class_values(
               self.module_name, class_values, 'prob', arch.to_json())
   
       def test_validate_class_values_last_layer_not_dense_multiio(self):
           # Two stacked Dense layers with n_classes units each, followed by
           # Activation layers, so the final layer is not Dense. Two dependent
           # variables are supplied, each with n_classes class values; this is
           # expected to validate without raising.
           n_classes = 3
           arch = Sequential()
           arch.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
                           input_shape=(1, 1, 1), padding='same'))
           for _ in range(2):
               arch.add(Dense(n_classes))
           for act_name in ('relu', 'softmax'):
               arch.add(Activation(act_name))

           class_values = [range(n_classes) for _ in range(2)]
           self.subject.validate_class_values(
               self.module_name, class_values, 'prob', arch.to_json())
   
       def test_validate_class_values_mismatch(self):
           # Class values that disagree with the Dense layers of the
           # architecture must be rejected with a "do not match" error.
           expected_error_regex = ".*do not match.*architecture"
           n_classes = 3

           def conv_base():
               # Shared single-Conv2D prefix used by both cases below.
               base = Sequential()
               base.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
                               input_shape=(1, 1, 1), padding='same'))
               return base

           # Case 1: only one Dense layer, but two dependent variables.
           arch = conv_base()
           arch.add(Dense(n_classes))
           arch.add(Activation('relu'))
           with self.assertRaisesRegexp(plpy.PLPYException, expected_error_regex):
               self.subject.validate_class_values(
                   self.module_name, [range(n_classes), range(n_classes)],
                   'prob', arch.to_json())

           # Case 2: two Dense layers, matching the two dependent variables,
           # but the first has 2 units while each dependent variable has
           # n_classes (3) class values.
           arch = conv_base()
           arch.add(Dense(2))
           arch.add(Dense(n_classes))
           arch.add(Activation('relu'))
           with self.assertRaisesRegexp(plpy.PLPYException, expected_error_regex):
               self.subject.validate_class_values(
                   self.module_name, [range(n_classes), range(n_classes)],
                   'prob', arch.to_json())
   
       def test_validate_class_values_no_units(self):
           # The architecture has no Dense layer at all. An "Unable ...
           # classes ... architecture" error is expected both for one
           # dependent variable and for two dependent variables.
           expected_error_regex = ".*Unable.*classes.*architecture"
           n_classes = 3
           arch = Sequential()
           arch.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
                           input_shape=(1, 1, 1), padding='same'))
           arch_json = arch.to_json()

           for dep_var_count in (1, 2):
               class_values = [range(n_classes) for _ in range(dep_var_count)]
               with self.assertRaisesRegexp(plpy.PLPYException, expected_error_regex):
                   self.subject.validate_class_values(
                       self.module_name, class_values, 'prob', arch_json)
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org