#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql.tests.pandas.streaming.test_pandas_transform_with_state_state_variable import (
    TransformWithStateInPandasStateVariableTestsMixin,
)
from pyspark import SparkConf
from pyspark.testing.connectutils import ReusedConnectTestCase


class TransformWithStateInPandasStateVariableParityTests(
    TransformWithStateInPandasStateVariableTestsMixin, ReusedConnectTestCase
):
    """
    Spark Connect parity tests for TransformWithStateInPandas.

    Runs every test case defined in `TransformWithStateInPandasStateVariableTestsMixin`
    in Spark Connect mode.
    """

    @classmethod
    def conf(cls):
        """Build this suite's SparkConf by merging the confs of its direct bases.

        Both bases are inherited at the same level, so the settings from each of them
        are collected explicitly: every direct base that defines ``conf`` is queried in
        declaration order, and its entries are copied into a fresh, defaults-free
        SparkConf. A key set by a later base overwrites the same key from an earlier one.

        Returns:
            The merged SparkConf. If it is JVM-backed, ``spark.master`` is removed.
        """
        merged = SparkConf(loadDefaults=False)

        # Lazily filter so each base is checked and then queried in turn.
        conf_providers = (base for base in cls.__bases__ if hasattr(base, "conf"))
        for provider in conf_providers:
            for key, value in provider.conf().getAll():
                merged.set(key, value)

        # Connect suites drop ``spark.master`` from the merged conf. This is only
        # possible when a JVM-side conf exists (``_jconf`` is None otherwise).
        jvm_conf = merged._jconf
        if jvm_conf is not None:
            jvm_conf.remove("spark.master")

        return merged


if __name__ == "__main__":
    # Run this module's tests through PySpark's shared test entry point.
    from pyspark.testing import main as run_tests

    run_tests()
