#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta, abstractmethod

import copy
import threading

from typing import (
    Any,
    Callable,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
    overload,
    TYPE_CHECKING,
)

from pyspark import since
from pyspark.ml.param import P
from pyspark.ml.common import inherit_doc
from pyspark.ml.param.shared import (
    HasInputCol,
    HasOutputCol,
    HasLabelCol,
    HasFeaturesCol,
    HasPredictionCol,
    Params,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import udf
from pyspark.sql.types import DataType, StructField, StructType

if TYPE_CHECKING:
    from pyspark.ml._typing import ParamMap

T = TypeVar("T")
M = TypeVar("M", bound="Transformer")


class _FitMultipleIterator(Generic[M]):
    """
    Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
    iterator. This class handles the simple case of fitMultiple where each param map should be
    fit independently.

    Parameters
    ----------
    fitSingleModel : function
        Callable[[int], Transformer] which fits an estimator to a dataset.
        `fitSingleModel` may be called up to `numModels` times, with a unique index each time.
        Each call to `fitSingleModel` with an index should return the Model associated with
        that index.
    numModels : int
        Number of models this iterator should produce.

    Notes
    -----
    See :py:meth:`Estimator.fitMultiple` for more info.
    """

    def __init__(self, fitSingleModel: Callable[[int], M], numModels: int):
        """
        Creates an iterator that yields ``(index, model)`` for indices ``0 .. numModels - 1``.
        """
        self.fitSingleModel = fitSingleModel
        # The attribute keeps its historical singular name; only the parameter is plural.
        self.numModel = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __iter__(self) -> Iterator[Tuple[int, M]]:
        return self

    def __next__(self) -> Tuple[int, M]:
        """
        Claims the next unused index and returns ``(index, fitted model)``.

        Raises
        ------
        StopIteration
            Once all ``numModels`` indices have been handed out.
        """
        # Only the index bookkeeping is guarded, so two callers never receive the same
        # index. The fit itself runs outside the lock, which lets concurrent callers fit
        # different models at the same time.
        with self.lock:
            index = self.counter
            if index >= self.numModel:
                raise StopIteration("No models remaining.")
            self.counter += 1
        return index, self.fitSingleModel(index)

    def next(self) -> Tuple[int, M]:
        """For python2 compatibility."""
        return self.__next__()
@inherit_doc
class Estimator(Params, Generic[M], metaclass=ABCMeta):
    """
    Abstract class for estimators that fit models to data.

    .. versionadded:: 1.3.0
    """

    @abstractmethod
    def _fit(self, dataset: DataFrame) -> M:
        """
        Fits a model to the input dataset. This is called by the default implementation of fit.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            input dataset

        Returns
        -------
        :class:`Transformer`
            fitted model
        """
        raise NotImplementedError()

    def fitMultiple(
        self, dataset: DataFrame, paramMaps: Sequence["ParamMap"]
    ) -> Iterator[Tuple[int, M]]:
        """
        Fits a model to the input dataset for each param map in `paramMaps`.

        .. versionadded:: 2.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            input dataset.
        paramMaps : :py:class:`collections.abc.Sequence`
            A Sequence of param maps.

        Returns
        -------
        :py:class:`_FitMultipleIterator`
            A thread safe iterable which contains one model for each param map. Each
            call to `next(modelIterator)` will return `(index, model)` where model was fit
            using `paramMaps[index]`. `index` values may not be sequential.

        Notes
        -----
        This default implementation copies the estimator once. Each model is then fit from
        that copy with its own param map, only when the returned iterator is advanced.
        """
        estimator = self.copy()

        def fitSingleModel(index: int) -> M:
            # A single param map goes through the single-model ``fit`` overload, so this returns M.
            return estimator.fit(dataset, paramMaps[index])

        return _FitMultipleIterator(fitSingleModel, len(paramMaps))

    @overload
    def fit(self, dataset: DataFrame, params: Optional["ParamMap"] = ...) -> M:
        ...

    @overload
    def fit(
        self, dataset: DataFrame, params: Union[List["ParamMap"], Tuple["ParamMap", ...]]
    ) -> List[M]:
        ...

    def fit(
        self,
        dataset: DataFrame,
        params: Optional[Union["ParamMap", List["ParamMap"], Tuple["ParamMap", ...]]] = None,
    ) -> Union[M, List[M]]:
        """
        Fits a model to the input dataset with optional parameters.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            input dataset.
        params : dict or list or tuple, optional
            an optional param map that overrides embedded params. If a list/tuple of
            param maps is given, this calls fit on each param map and returns a list of
            models.

        Returns
        -------
        :py:class:`Transformer` or a list of :py:class:`Transformer`
            fitted model(s)

        Raises
        ------
        TypeError
            If `params` is neither a dict nor a list/tuple of dicts.
        """
        if params is None:
            params = dict()
        if isinstance(params, (list, tuple)):
            # fitMultiple may yield models out of order, so place each one by its index.
            models: List[Optional[M]] = [None] * len(params)
            for index, model in self.fitMultiple(dataset, params):
                models[index] = model
            return cast(List[M], models)
        elif isinstance(params, dict):
            if params:
                # Apply the overrides to a copy rather than to this estimator.
                return self.copy(params)._fit(dataset)
            else:
                return self._fit(dataset)
        else:
            raise TypeError(
                "Params must be either a param map or a list/tuple of param maps, "
                "but got %s." % type(params)
            )
@inherit_doc
class Transformer(Params, metaclass=ABCMeta):
    """
    Abstract base for transformers, which map one dataset onto another.

    Subclasses supply :py:meth:`_transform`; callers go through :py:meth:`transform`.

    .. versionadded:: 1.3.0
    """

    @abstractmethod
    def _transform(self, dataset: DataFrame) -> DataFrame:
        """
        Produces the transformed dataset using the params of the instance it is called on.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            input dataset.

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            transformed dataset
        """
        raise NotImplementedError()

    def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) -> DataFrame:
        """
        Transforms the input dataset, optionally overriding embedded params.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            input dataset
        params : dict, optional
            param map whose entries take precedence over the embedded params.

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            transformed dataset

        Raises
        ------
        TypeError
            If `params` is given but is not a dict.
        """
        overrides = dict() if params is None else params
        if not isinstance(overrides, dict):
            raise TypeError("Params must be a param map but got %s." % type(overrides))
        # An empty map needs no copy: run directly on this instance.
        target = self.copy(overrides) if overrides else self
        return target._transform(dataset)
@inherit_doc
class Model(Transformer, metaclass=ABCMeta):
    """
    Abstract base for models: transformers that are produced by fitting an estimator.

    .. versionadded:: 1.4.0
    """
@inherit_doc
class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
    """
    Abstract class for transformers that take one input column, apply transformation,
    and output the result as a new column.

    Subclasses implement :py:meth:`createTransformFunc`, :py:meth:`outputDataType` and
    :py:meth:`validateInputType`. Schema checking and column appending are provided here.

    .. versionadded:: 2.3.0
    """

    def setInputCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @abstractmethod
    def createTransformFunc(self) -> Callable[..., Any]:
        """
        Creates the Python function that is applied to each value of the input column.
        Any param values it needs should be read from this instance. When
        :py:meth:`transform` is called with extra params, it runs on a copy that already
        carries them.
        """
        raise NotImplementedError()

    @abstractmethod
    def outputDataType(self) -> DataType:
        """
        Returns the data type of the output column.
        """
        raise NotImplementedError()

    @abstractmethod
    def validateInputType(self, inputType: DataType) -> None:
        """
        Validates the input type. Throw an exception if it is invalid.
        """
        raise NotImplementedError()

    def transformSchema(self, schema: StructType) -> StructType:
        """
        Validates the input column and returns `schema` with the output column appended.

        The input column's data type is passed to :py:meth:`validateInputType`. The new
        field is typed by :py:meth:`outputDataType` and declared non-nullable.

        Raises
        ------
        ValueError
            If `schema` already contains a column named like the output column.
        """
        inputType = schema[self.getInputCol()].dataType
        self.validateInputType(inputType)
        if self.getOutputCol() in schema.names:
            raise ValueError("Output column %s already exists." % self.getOutputCol())
        # Shallow-copy the field list so the caller's schema is not mutated.
        outputFields = copy.copy(schema.fields)
        outputFields.append(
            StructField(self.getOutputCol(), self.outputDataType(), nullable=False)
        )
        return StructType(outputFields)

    def _transform(self, dataset: DataFrame) -> DataFrame:
        """
        Appends the output column. Its values come from applying the function returned by
        :py:meth:`createTransformFunc`, wrapped as a UDF returning
        :py:meth:`outputDataType`, to the input column.
        """
        # Fail fast on a bad input type or an output-column clash before building the UDF.
        self.transformSchema(dataset.schema)
        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
        transformedDataset = dataset.withColumn(
            self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
        )
        return transformedDataset
@inherit_doc
class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
    """
    Params for :py:class:`Predictor` and :py:class:`PredictionModel`.

    Combines the label, features and prediction column params.

    .. versionadded:: 3.0.0
    """
@inherit_doc
class Predictor(Estimator[M], _PredictorParams, metaclass=ABCMeta):
    """
    Base estimator for prediction tasks, e.g. regression and classification.

    Adds setters for the label, features and prediction column params.
    """

    @since("3.0.0")
    def setLabelCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`labelCol`.

        Returns whatever ``_set`` returns, typed as this same class.
        """
        updated = self._set(labelCol=value)
        return updated

    @since("3.0.0")
    def setFeaturesCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`featuresCol`.

        Returns whatever ``_set`` returns, typed as this same class.
        """
        updated = self._set(featuresCol=value)
        return updated

    @since("3.0.0")
    def setPredictionCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`predictionCol`.

        Returns whatever ``_set`` returns, typed as this same class.
        """
        updated = self._set(predictionCol=value)
        return updated
@inherit_doc
class PredictionModel(Model, _PredictorParams, Generic[T], metaclass=ABCMeta):
    """
    Base model for prediction tasks, e.g. regression and classification.

    `T` is the type of a single feature value accepted by :py:meth:`predict`.
    """

    @since("3.0.0")
    def setFeaturesCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        updated = self._set(featuresCol=value)
        return updated

    @since("3.0.0")
    def setPredictionCol(self: P, value: str) -> P:
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        updated = self._set(predictionCol=value)
        return updated

    @property
    @abstractmethod
    @since("2.1.0")
    def numFeatures(self) -> int:
        """
        Number of features seen during training, or -1 when that is not known.
        """
        raise NotImplementedError()

    @abstractmethod
    @since("3.0.0")
    def predict(self, value: T) -> float:
        """
        Returns the predicted label for a single features value.
        """
        raise NotImplementedError()