#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
import inspect
import functools
import os
from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    List,
    Sequence,
    TYPE_CHECKING,
    cast,
    TypeVar,
    Union,
)

# For backward compatibility.
from pyspark.errors import (  # noqa: F401
    AnalysisException,
    ParseException,
    IllegalArgumentException,
    StreamingQueryException,
    QueryExecutionException,
    PythonException,
    UnknownException,
    SparkUpgradeException,
    PySparkNotImplementedError,
    PySparkRuntimeError,
)
from pyspark.util import is_remote_only, JVM_INT_MAX
from pyspark.errors.exceptions.captured import CapturedException  # noqa: F401
from pyspark.find_spark_home import _find_spark_home

if TYPE_CHECKING:
    from py4j.java_collections import JavaArray
    from py4j.java_gateway import (
        JavaClass,
        JavaGateway,
        JavaObject,
        JVMView,
    )
    from pyspark import SparkContext
    from pyspark.sql.session import SparkSession
    from pyspark.sql.dataframe import DataFrame

    from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex

has_numpy: bool = False
try:
    import numpy as np  # noqa: F401

    has_numpy = True
except ImportError:
    pass


FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def to_java_array(gateway: "JavaGateway", jtype: "JavaClass", arr: Sequence[Any]) -> "JavaArray":
    """
    Convert python list to java type array

    Parameters
    ----------
    gateway :
        Py4j Gateway
    jtype :
        java type of element in array
    arr :
        python type list
    """
    jarray: "JavaArray" = gateway.new_array(jtype, len(arr))
    for i in range(0, len(arr)):
        jarray[i] = arr[i]
    return jarray


def to_scala_map(jvm: "JVMView", dic: Dict) -> "JavaObject":
    """
    Convert a dict into a Scala Map.
    """
    assert jvm is not None
    return jvm.PythonUtils.toScalaMap(dic)


def require_test_compiled() -> None:
    """Raise Exception if test classes are not compiled"""
    import os
    import glob

    test_class_path = os.path.join(
        _find_spark_home(), "sql", "core", "target", "*", "test-classes"
    )
    paths = glob.glob(test_class_path)

    if len(paths) == 0:
        raise PySparkRuntimeError(
            errorClass="TEST_CLASS_NOT_COMPILED",
            messageParameters={"test_class_path": test_class_path},
        )


class ForeachBatchFunction:
    """
    This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
    the user-defined 'foreachBatch' function such that it can be called from the JVM when
    the query is active.
    """

    def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
        self.func = func
        self.session = session

    def call(self, jdf: "JavaObject", batch_id: int) -> None:
        from pyspark.sql.dataframe import DataFrame
        from pyspark.sql.session import SparkSession

        try:
            session_jdf = jdf.sparkSession()
            # assuming that spark context is still the same between JVM and PySpark
            wrapped_session_jdf = SparkSession(self.session.sparkContext, session_jdf)
            self.func(DataFrame(jdf, wrapped_session_jdf), batch_id)
        except Exception as e:
            self.error = e
            raise e

    class Java:
        implements = [
            "org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction"
        ]


# Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat'
class StringConcat:
    def __init__(self, maxLength: int = JVM_INT_MAX - 15):
        self.maxLength: int = maxLength
        self.strings: List[str] = []
        self.length: int = 0

    def atLimit(self) -> bool:
        return self.length >= self.maxLength

    def append(self, s: str) -> None:
        if s is not None:
            sLen = len(s)
            if not self.atLimit():
                available = self.maxLength - self.length
                stringToAppend = s if available >= sLen else s[0:available]
                self.strings.append(stringToAppend)

            self.length = min(self.length + sLen, JVM_INT_MAX - 15)

    def toString(self) -> str:
        # finalLength = self.maxLength if self.atLimit() else self.length
        return "".join(self.strings)


# Python implementation of 'org.apache.spark.util.SparkSchemaUtils.escapeMetaCharacters'
def escape_meta_characters(s: str) -> str:
    return (
        s.replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
        .replace("\f", "\\f")
        .replace("\b", "\\b")
        .replace("\u000B", "\\v")
        .replace("\u0007", "\\a")
    )


def to_str(value: Any) -> Optional[str]:
    """
    A wrapper over str(), but converts bool values to lower case strings.
    If None is given, just returns None, instead of converting it to string "None".
    """
    if isinstance(value, bool):
        return str(value).lower()
    elif value is None:
        return value
    else:
        return str(value)


def enum_to_value(value: Any) -> Any:
    """Convert an Enum to its value if it is not None."""
    return enum_to_value(value.value) if value is not None and isinstance(value, Enum) else value


def is_timestamp_ntz_preferred() -> bool:
    """
    Return a bool if TimestampNTZType is preferred according to the SQL configuration set.
    """
    if is_remote():
        from pyspark.sql.connect.session import SparkSession as ConnectSparkSession

        session = ConnectSparkSession.getActiveSession()
        if session is None:
            return False
        else:
            return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"
    else:
        from pyspark import SparkContext

        jvm = SparkContext._jvm
        return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()

def is_remote() -> bool:
    """
    Returns if the current running environment is for Spark Connect.

    .. versionadded:: 4.0.0

    Notes
    -----
    This will only return ``True`` if there is a remote session running.
    Otherwise, it returns ``False``.

    This API is unstable, and for developers.

    Returns
    -------
    bool

    Examples
    --------
    >>> from pyspark.sql import is_remote
    >>> is_remote()
    False
    """
    return ("SPARK_CONNECT_MODE_ENABLED" in os.environ) or is_remote_only()

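# The try_remote_* and dispatch_* decorators below share one dispatch pattern: when
# is_remote() is true and the "PYSPARK_NO_NAMESPACE_SHARE" escape hatch is not set in the
# environment, the call is forwarded by __name__ to the corresponding Spark Connect module;
# otherwise the wrapped classic (JVM-backed) implementation runs as-is. A hedged,
# comment-only sketch of how such a decorator is applied (`my_func` is a hypothetical
# function name used purely for illustration):
#
#   @try_remote_functions
#   def my_func(col):
#       ...  # classic implementation; Connect resolves `my_func` by name when remote

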
def try_remote_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect import functions

            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_partitioning_remote_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.functions import partitioning

            return getattr(partitioning, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_remote_avro_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.avro import functions

            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_remote_protobuf_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.protobuf import functions

            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def get_active_spark_context() -> "SparkContext":
    """Raise RuntimeError if SparkContext is not initialized,
    otherwise, returns the active SparkContext."""
    from pyspark import SparkContext

    sc = SparkContext._active_spark_context
    if sc is None or sc._jvm is None:
        raise PySparkRuntimeError(
            errorClass="SESSION_OR_CONTEXT_NOT_EXISTS",
            messageParameters={},
        )
    return sc


def try_remote_session_classmethod(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.session import SparkSession

            assert inspect.isclass(args[0])
            return getattr(SparkSession, f.__name__)(*args[1:], **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def dispatch_df_method(f: FuncT) -> FuncT:
    """
    For the use cases of direct DataFrame.method(df, ...), it checks if self
    is a Connect DataFrame or Classic DataFrame, and dispatches.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

            if isinstance(args[0], ConnectDataFrame):
                return getattr(ConnectDataFrame, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame

            if isinstance(args[0], ClassicDataFrame):
                return getattr(ClassicDataFrame, f.__name__)(*args, **kwargs)

        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": f"DataFrame.{f.__name__}"},
        )

    return cast(FuncT, wrapped)


def dispatch_col_method(f: FuncT) -> FuncT:
    """
    For the use cases of direct Column.method(col, ...), it checks if self
    is a Connect Column or Classic Column, and dispatches.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.column import Column as ConnectColumn

            if isinstance(args[0], ConnectColumn):
                return getattr(ConnectColumn, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.column import Column as ClassicColumn

            if isinstance(args[0], ClassicColumn):
                return getattr(ClassicColumn, f.__name__)(*args, **kwargs)

        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": f"Column.{f.__name__}"},
        )

    return cast(FuncT, wrapped)


def dispatch_window_method(f: FuncT) -> FuncT:
    """
    For use cases of direct Window.method(col, ...), this function dispatches
    the call to either ConnectWindow or ClassicWindow based on the execution environment.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.window import Window as ConnectWindow

            return getattr(ConnectWindow, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.window import Window as ClassicWindow

            return getattr(ClassicWindow, f.__name__)(*args, **kwargs)

    return cast(FuncT, wrapped)


def pyspark_column_op(
    func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None
) -> Union["SeriesOrIndex", None]:
    """
    Wrapper function for column_op to get proper Column class.
    """
    from pyspark.pandas.base import column_op
    from pyspark.sql.column import Column
    from pyspark.pandas.data_type_ops.base import _is_extension_dtypes

    result = column_op(getattr(Column, func_name))(left, right)
    # It works as expected on extension dtype, so we don't need to call `fillna` for this case.
    if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
        fillna = None
    # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
    return result.fillna(fillna) if fillna is not None else result


def get_lit_sql_str(val: str) -> str:
    # Equivalent to `lit(val)._jc.expr().sql()` for string typed val
    # See `sql` definition in `sql/catalyst/src/main/scala/org/apache/spark/
    # sql/catalyst/expressions/literals.scala`
    return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"