데이터과학 삼학년

[Tensorflow] tf.math.top_k을 이용해서 probs, indices를 한번에! 본문

Machine Learning

[Tensorflow] tf.math.top_k을 이용해서 probs, indices를 한번에!

Dan-k 2023. 3. 18. 12:55
반응형

텐서플로우에도 파이토치의 top_k처럼 모델이 output으로 내뱉는 결과중 상위 k개를 뽑아내는 함수가 있다.

tf.math.top_k(output, k)

이 것을 통하면 prob과 indices를 한번에 확인 가능하다!!!

import tensorflow as tf

# 라벨 딕셔너리 정의
label_dict = {0: 'cat', 1: 'dog', 2: 'bird'}

# 가장 높은 값과 인덱스 찾기
values, indices = tf.math.top_k(output, k=1)
print("Predicted index:", indices.numpy()[0])print("Predicted value:", values.numpy()[0])

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32),
                              tf.TensorSpec(shape=[None], dtype=tf.int32)])
def get_top_k_labels(values, indices):
    # tf.gather_nd 함수를 사용하여 인덱스에 해당하는 라벨을 가져옴
    labels = tf.gather_nd(tf.constant(list(label_dict.values())), tf.expand_dims(indices, axis=-1))
    return labels

# 예시 데이터 생성
#values = tf.constant([[0.2, 0.4, 0.6, 0.8], [0.1, 0.5, 0.3, 0.9]])
#indices = tf.constant([[1, 3, 2, 0], [2, 1, 0, 3]])

# get_top_k_labels 함수 호출
labels = get_top_k_labels(values, indices)

# 결과 출력
print(labels)
728x90
반응형
LIST
Comments