Использование предиктора Sagemaker в функции Spark UDF

Я пытаюсь выполнить вывод для модели Tensorflow, развернутой в SageMaker из задания Python Spark. Я запускаю блокнот (Databricks) со следующей ячейкой:

def call_predict():
        batch_size = 1
        data = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2]]
        tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[batch_size, len(data[0])], dtype=tf.float32)      
        prediction = predictor.predict(tensor_proto)
        print("Process time: {}".format((time.clock() - start)))
        return prediction

Если я просто вызываю call_predict (), он работает нормально:

call_predict()

и я получаю вывод:

Process time: 65.261396
Out[61]: {'model_spec': {'name': u'generic_model',
  'signature_name': u'serving_default',
  'version': {'value': 1578909324L}},
 'outputs': {u'ages': {'dtype': 1,
   'float_val': [5.680944442749023],
   'tensor_shape': {'dim': [{'size': 1L}]}}}}

но когда я пытаюсь позвонить из контекста Spark (в UDF), я получаю ошибку сериализации. Код, который я пытаюсь запустить:

dataRange = range(1, 10001)
rangeRDD = sc.parallelize(dataRange, 8)
new_data = rangeRDD.map(lambda x : call_predict())
new_data.count()

и ошибка, которую я получаю:

---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<command-2282434> in <module>()
      2 rangeRDD = sc.parallelize(dataRange, 8)
      3 new_data = rangeRDD.map(lambda x : call_predict())
----> 4 new_data.count()
      5 

/databricks/spark/python/pyspark/rdd.pyc in count(self)
   1094         3
   1095         """
-> 1096         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
   1097 
   1098     def stats(self):

/databricks/spark/python/pyspark/rdd.pyc in sum(self)
   1085         6.0
   1086         """
-> 1087         return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
   1088 
   1089     def count(self):

/databricks/spark/python/pyspark/rdd.pyc in fold(self, zeroValue, op)
    956         # zeroValue provided to each partition is unique from the one provided
    957         # to the final reduce call
--> 958         vals = self.mapPartitions(func).collect()
    959         return reduce(op, vals, zeroValue)
    960 

/databricks/spark/python/pyspark/rdd.pyc in collect(self)
    829         # Default path used in OSS Spark / for non-credential passthrough clusters:
    830         with SCCallSiteSync(self.context) as css:
--> 831             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    832         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    833 

/databricks/spark/python/pyspark/rdd.pyc in _jrdd(self)
   2573 
   2574         wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
-> 2575                                       self._jrdd_deserializer, profiler)
   2576         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
   2577                                              self.preservesPartitioning, self.is_barrier)

/databricks/spark/python/pyspark/rdd.pyc in _wrap_function(sc, func, deserializer, serializer, profiler)
   2475     assert serializer, "serializer should not be empty"
   2476     command = (func, profiler, deserializer, serializer)
-> 2477     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
   2478     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
   2479                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)

/databricks/spark/python/pyspark/rdd.pyc in _prepare_for_python_RDD(sc, command)
   2461     # the serialized command will be compressed by broadcast
   2462     ser = CloudPickleSerializer()
-> 2463     pickled_command = ser.dumps(command)
   2464     if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):  # Default 1M
   2465         # The broadcast will have same life cycle as created PythonRDD

/databricks/spark/python/pyspark/serializers.pyc in dumps(self, obj)
    709                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
    710             cloudpickle.print_exec(sys.stderr)
--> 711             raise pickle.PicklingError(msg)
    712 
    713 

PicklingError: Could not serialize object: TypeError: can't pickle _ssl._SSLSocket objects

Не уверен, что это за ошибка сериализации - это жалоба на неудачу при десериализации Predictor

В моей записной книжке есть ячейка, которая была вызвана перед указанными выше ячейками со следующим импортом:

import sagemaker
import boto3
from sagemaker.tensorflow.model import TensorFlowPredictor
import tensorflow as tf
import numpy as np
import time

Предиктор был создан со следующим кодом:

sagemaker_client = boto3.client('sagemaker', aws_access_key_id=ACCESS_KEY,
                                aws_secret_access_key=SECRET_KEY, region_name='us-east-1')
sagemaker_runtime_client = boto3.client('sagemaker-runtime', aws_access_key_id=ACCESS_KEY,
                                        aws_secret_access_key=SECRET_KEY, region_name='us-east-1')

boto_session = boto3.Session(region_name='us-east-1')
sagemaker_session = sagemaker.Session(boto_session, sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client)

predictor = TensorFlowPredictor('endpoint-poc', sagemaker_session)

Всего 1 ответ


Функция udf будет выполняться несколькими искровыми задачами параллельно. Эти задачи выполняются в полностью изолированных процессах Python, и они запланированы на физически разных машинах. Следовательно, все данные, ссылки на эти функции должны быть на этом узле. Это относится ко всему, что создано в udf.

Всякий раз, когда вы ссылаетесь на какой-либо объект вне udf из функции, эта структура данных должна быть сериализована (протравлена) каждому исполнителю. Это прекрасно работает для небольших наборов данных, но почти всегда терпит неудачу для больших моделей нейронных сетей.

Вы должны убедиться, что модель tenorflow создается на каждом исполнителе путем ленивой загрузки. Это должно произойти только при первом вызове функции этого исполнителя. Делать это при каждом обращении к udf будет ужасно медленно.

Обычно для этого можно использовать шаблон Singleton. Но в питоне люди используют шаблон Борга.

class Env:
    _shared_state = {
        "sagemaker_client": None
        "sagemaker_runtime_client": None
        "boto_session": None
        "sagemaker_session": None
        "predictor": None
    }
    def __init__(self):
        self.__dict__ = self._shared_state
        if not self.predictor:
            self.sagemaker_client = boto3.client('sagemaker', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=SECRET_KEY, region_name='us-east-1')
            self.sagemaker_runtime_client = boto3.client('sagemaker-runtime', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=SECRET_KEY, region_name='us-east-1')

            self.boto_session = boto3.Session(region_name='us-east-1')
            self.sagemaker_session = sagemaker.Session(self.boto_session, sagemaker_client=self.sagemaker_client, sagemaker_runtime_client=self.sagemaker_runtime_client)

            self.predictor = TensorFlowPredictor('endpoint-poc', self.sagemaker_session)


#....
def call_predict():
   env = Env()
   batch_size = 1
   data = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2]]
   tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[batch_size, len(data[0])], dtype=tf.float32)      
   prediction = env.predictor.predict(tensor_proto)

   print("Process time: {}".format((time.clock() - start)))
        return prediction

new_data = rangeRDD.map(lambda x : call_predict())

Класс Env определен на главном узле. В его _shared_state есть пустые записи. Когда класс Env создается один раз, он делит состояние со всеми последующими экземплярами Env при любом последующем обращении к udf. На каждом отдельном параллельном процессе это происходит ровно один раз. Таким образом, сеанс является общим и не нужно мариновать.