250x250
반응형
Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- XAI
- grad-cam
- requests
- 유튜브 API
- 상관관계
- API
- 공분산
- session 유지
- flask
- chatGPT
- login crawling
- spark udf
- Airflow
- API Gateway
- GenericGBQException
- Counterfactual Explanations
- airflow subdag
- UDF
- Retry
- GCP
- top_k
- youtube data
- gather_nd
- BigQuery
- TensorFlow
- tensorflow text
- correlation
- hadoop
- subdag
- integrated gradient
Archives
- Today
- Total
데이터과학 삼학년
[Tensorflow] tf.math.top_k을 이용해서 probs, indices를 한번에! 본문
반응형
텐서플로우에도 파이토치의 top_k처럼 모델이 output으로 내뱉는 결과중 상위 k개를 뽑아내는 함수가 있다.
tf.math.top_k(output, k)
이 것을 통하면 prob과 indices를 한번에 확인 가능하다!!!
import tensorflow as tf
# 라벨 딕셔너리 정의
label_dict = {0: 'cat', 1: 'dog', 2: 'bird'}
# 가장 높은 값과 인덱스 찾기
values, indices = tf.math.top_k(output, k=1)
print("Predicted index:", indices.numpy()[0])print("Predicted value:", values.numpy()[0])
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32),
tf.TensorSpec(shape=[None], dtype=tf.int32)])
def get_top_k_labels(values, indices):
# tf.gather_nd 함수를 사용하여 인덱스에 해당하는 라벨을 가져옴
labels = tf.gather_nd(tf.constant(list(label_dict.values())), tf.expand_dims(indices, axis=-1))
return labels
# 예시 데이터 생성
#values = tf.constant([[0.2, 0.4, 0.6, 0.8], [0.1, 0.5, 0.3, 0.9]])
#indices = tf.constant([[1, 3, 2, 0], [2, 1, 0, 3]])
# get_top_k_labels 함수 호출
labels = get_top_k_labels(values, indices)
# 결과 출력
print(labels)
728x90
반응형
LIST
'Machine Learning' 카테고리의 다른 글
LGBM(LightGBM) (0) | 2023.05.10 |
---|---|
Pandas-AI (pandas 활용을 chatGPT 명령에 따라 실행) (0) | 2023.05.08 |
TensorFlow 모델 서빙 방법 (feat. ChatGPT) (0) | 2023.03.14 |
pytorch vocab 로드시 error 처리 (‘Vocab' object has no attribute '_modules' ~) (0) | 2023.02.22 |
Logistic Regression은 왜 linear classification 인가? (0) | 2022.11.03 |
Comments