데이터과학 삼학년

[tensorflow in spark] spark를 이용해 tf model을 분산 처리?! 본문

Machine Learning

[tensorflow in spark] spark를 이용해 tf model을 분산 처리?!

Dan-k 2024. 3. 8. 14:54
반응형

tensorflow 모델을 spark 분산환경을 위해서 inference하는 방법

 

- spark udf (pandas udf)를 활용해서 각 worker node에 inference 코드를 전달!!

 

방안1) driver에서 모델을 로드해서 pandas udf에 모델을 주는 방법

- 이 방법은 pickling error가 나옴

- 이유는 driver에서 worker node에 udf를 이용해 파일/데이터를 전달할때 pickle화 시키는 것으로 보임

-> pickle화 시키는 과정에서 tensorflow model을 pickling하지 못해서 나오는 이슈로 보임

- 아래 코드는 error가 나올 것 : PicklingError: Can't pickle...

# 데이터 생성
data = [("row1", 1.0, 2.0, 3.0),
        ("row2", 4.0, 5.0, 6.0)]

columns = ["id", "feature1", "feature2", "feature3"]
df = spark.createDataFrame(data, columns)

features = ["feature2", "feature3"]

model = tf.keras.models.load_model(model_path)

@F.pandas_udf(returnType=ArrayType(DoubleType()))
def predict_pandas_udf(*cols):
    import numpy as np
    import pandas as pd

    df = pd.concat(cols, axis=1, keys=features)
    y_pred = model(dict(df))

    return pd.Series(list(tensor.numpy() for tensor in y_pred))
				
pred_sdf = (
    sdf.select(
        F.col("id"),
        predict_pandas_udf(*features).alias("predictions"),
    )

 

방안2) woker 노드에서 직접 모델을 로드해서 pandas udf로 inference하는 방안

- worker 노드에서 직접 모델 path를 참고하여 로드하기 때문에 정상 동작

- driver에서 리소스를 관리하는 방향에 대해서 올바른 방안인지는 의문이 생김

# 데이터 생성
data = [("row1", 1.0, 2.0, 3.0),
        ("row2", 4.0, 5.0, 6.0)]

columns = ["id", "feature1", "feature2", "feature3"]
df = spark.createDataFrame(data, columns)

features = ["feature2", "feature3"]


@F.pandas_udf(returnType=ArrayType(DoubleType()))
def predict_pandas_udf(*cols):
    import numpy as np
    import pandas as pd

    df = pd.concat(cols, axis=1, keys=features)
	model = tf.keras.models.load_model(model_path)

    return pd.Series(list(tensor.numpy() for tensor in model(dict(df))))
			
            
pred_sdf = (
    sdf.select(
        F.col("id"),
        predict_pandas_udf(*features).alias("predictions"),
    )

 

multi class 분류의 경우

predictions의 컬럼값이 아래와 같이 array형태로 들어가게 될텐데,,

[[0.3, 0.4, 0.3],

[0.2, 0.2, 0.6],

[0.3, 0.3, 0.3], ...]

 

이를 class별로 풀려면 아래와 같이!!

 pred_sdf = (
        sdf.select(
            F.col("id"),
            predict_pandas_udf(*features).alias("predictions"),
        )
        .select(
            F.col("id"),
            *[
                F.col("predictions")[i].alias(f"Class_{c}")
                for i, c in enumerate(range(num_classes))
            ],
        )
    )
    
    print(pred_sdf.take(10))

 

728x90
반응형
LIST
Comments