You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@spark.apache.org by "Borys Biletskyy (Jira)" <ji...@apache.org> on 2019/10/09 14:43:00 UTC

[jira] [Created] (SPARK-29414) HasOutputCol param isSet() property is not preserved after persistence

Borys Biletskyy created SPARK-29414:
---------------------------------------

             Summary: HasOutputCol param isSet() property is not preserved after persistence
                 Key: SPARK-29414
                 URL: https://issues.apache.org/jira/browse/SPARK-29414
             Project: Spark
          Issue Type: Bug
          Components: ML, PySpark
    Affects Versions: 2.3.2
            Reporter: Borys Biletskyy


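The isSet() state of the HasOutputCol param is not preserved after saving and loading a model with DefaultParamsWritable and DefaultParamsReadable: an outputCol that only has its default value (isDefined() is True, isSet() is False) is reported as set (isSet() is True) on the loaded model.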
{code:python}
import pytest
from pyspark import keyword_only
from pyspark.ml import Model
from pyspark.sql import DataFrame
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.functions import *


class HasOutputColTester(Model,
                         HasInputCol,
                         HasOutputCol,
                         DefaultParamsReadable,
                         DefaultParamsWritable
                         ):
    """Identity model used to check how param state survives persistence.

    Mixes in ``HasInputCol`` and ``HasOutputCol`` and is saved/loaded through
    ``DefaultParamsWritable`` / ``DefaultParamsReadable``. ``_transform``
    returns its input DataFrame unchanged.
    """

    @keyword_only
    def __init__(self, inputCol: str = None, outputCol: str = None):
        super().__init__()
        # @keyword_only stores the keyword arguments actually passed in
        # self._input_kwargs; forward exactly those to setParams.
        self.setParams(**self._input_kwargs)

    @keyword_only
    def setParams(self, inputCol: str = None, outputCol: str = None):
        """Explicitly set the params passed by keyword and return ``self``."""
        explicit_params = self._input_kwargs
        self._set(**explicit_params)
        return self

    def _transform(self, data: DataFrame) -> DataFrame:
        # Identity transform: the model only exists to exercise param persistence.
        return data


class TestHasInputColParam(object):
    """Regression test for SPARK-29414.

    An ``outputCol`` that only carries its default value must still report
    ``isSet() == False`` after a save/load round trip.
    """

    def test_persist_input_col_set(self, spark, temp_dir):
        # `spark` is not used in the body; presumably it is requested so that a
        # SparkSession exists before save()/load() -- TODO confirm the fixture.
        import os

        path = os.path.join(temp_dir, 'test_model')
        model = HasOutputColTester()
        # inputCol was not passed and has no default: neither defined nor set.
        assert not model.isDefined(model.inputCol)
        assert not model.isSet(model.inputCol)

        # outputCol is defined (it has a default value) but was never set
        # explicitly.
        assert model.isDefined(model.outputCol)
        assert not model.isSet(model.outputCol)
        model.write().overwrite().save(path)

        loaded_model: HasOutputColTester = HasOutputColTester.load(path)
        # Query the loaded model through its own Param objects so the checks do
        # not depend on the loaded model reusing the original model's uid.
        assert not loaded_model.isDefined(loaded_model.inputCol)
        assert not loaded_model.isSet(loaded_model.inputCol)

        assert loaded_model.isDefined(loaded_model.outputCol)
        # Fails on affected versions: after loading, the default-only outputCol
        # is reported as set (AssertionError: assert not True).
        assert not loaded_model.isSet(loaded_model.outputCol)
{code}



--
This message was sent by Atlassian Jira
(v8.3.4#803005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org
For additional commands, e-mail: issues-help@spark.apache.org