Machine Learning
[tensorflow in spark] spark를 이용해 tf model을 분산 처리?!
Dan-k
2024. 3. 8. 14:54
반응형
tensorflow 모델을 spark 분산환경을 위해서 inference하는 방법
- spark udf (pandas udf)를 활용해서 각 worker node에 inference 코드를 전달!!
방안1) driver에서 모델을 로드해서 pandas udf에 모델을 주는 방법
- 이 방법은 pickling error가 나옴
- 이유는 driver에서 worker node에 udf를 이용해 파일/데이터를 전달할때 pickle화 시키는 것으로 보임
-> pickle화 시키는 과정에서 tensorflow model을 pickling하지 못해서 나오는 이슈로 보임
- 아래 코드는 error가 나올 것 : PicklingError: Can't pickle...
# 데이터 생성
data = [("row1", 1.0, 2.0, 3.0),
("row2", 4.0, 5.0, 6.0)]
columns = ["id", "feature1", "feature2", "feature3"]
df = spark.createDataFrame(data, columns)
features = ["feature2", "feature3"]
model = tf.keras.models.load_model(model_path)
@F.pandas_udf(returnType=ArrayType(DoubleType()))
def predict_pandas_udf(*cols):
import numpy as np
import pandas as pd
df = pd.concat(cols, axis=1, keys=features)
y_pred = model(dict(df))
return pd.Series(list(tensor.numpy() for tensor in y_pred))
pred_sdf = (
sdf.select(
F.col("id"),
predict_pandas_udf(*features).alias("predictions"),
)
방안2) woker 노드에서 직접 모델을 로드해서 pandas udf로 inference하는 방안
- worker 노드에서 직접 모델 path를 참고하여 로드하기 때문에 정상 동작
- driver에서 리소스를 관리하는 방향에 대해서 올바른 방안인지는 의문이 생김
# 데이터 생성
data = [("row1", 1.0, 2.0, 3.0),
("row2", 4.0, 5.0, 6.0)]
columns = ["id", "feature1", "feature2", "feature3"]
df = spark.createDataFrame(data, columns)
features = ["feature2", "feature3"]
@F.pandas_udf(returnType=ArrayType(DoubleType()))
def predict_pandas_udf(*cols):
import numpy as np
import pandas as pd
df = pd.concat(cols, axis=1, keys=features)
model = tf.keras.models.load_model(model_path)
return pd.Series(list(tensor.numpy() for tensor in model(dict(df))))
pred_sdf = (
sdf.select(
F.col("id"),
predict_pandas_udf(*features).alias("predictions"),
)
multi class 분류의 경우
predictions의 컬럼값이 아래와 같이 array형태로 들어가게 될텐데,,
[[0.3, 0.4, 0.3],
[0.2, 0.2, 0.6],
[0.3, 0.3, 0.3], ...]
이를 class별로 풀려면 아래와 같이!!
pred_sdf = (
sdf.select(
F.col("id"),
predict_pandas_udf(*features).alias("predictions"),
)
.select(
F.col("id"),
*[
F.col("predictions")[i].alias(f"Class_{c}")
for i, c in enumerate(range(num_classes))
],
)
)
print(pred_sdf.take(10))
728x90
반응형
LIST