TensorFlow:tf.map_fn函數(shù)

2018-10-30 17:51 更新
函數(shù):tf.map_fn
map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)

定義在:tensorflow/python/ops/functional_ops.py.

參見指南:高階函數(shù)>高階運算符

從0維度的 elems 中解壓的張量列表上的映射.

map_fn 的最簡單版本反復(fù)地將可調(diào)用的 fn 應(yīng)用于從第一個到最后一個的元素序列.這些元素由 elems 解壓縮的張量構(gòu)成.dtype 是 fn 的返回值的數(shù)據(jù)類型.如果與elems 的數(shù)據(jù)類型不同,用戶必須提供 dtype.

假設(shè) elems 被打包成 values、張量列表.結(jié)果張量的形狀是:[values.shape[0]] + fn(values[0]).shape.

這種方法也允許 fn 的多元 elems 和輸出.如果 elems 是(可能是嵌套的)列表或元素的張量,則這些張量中的每一個必須具有匹配的第一(unpack)維度.簽名fn可能匹配的結(jié)構(gòu)elems.也就是說,如果 elems 是:(t1, [t2, t3, [t4, t5]]),則 fn 的適當(dāng)簽名為:fn = lambda (t1, [t2, t3, [t4, t5]]):.

此外,fn 可能會發(fā)出與其輸入不同的結(jié)構(gòu).例如,fn 可能看起來像:fn = lambda t1: return (t1 + 1, t1 - 1).在這種情況下,dtype 參數(shù)不是可選的:dtype 必須是與 fn的輸出匹配的類型或(可能是嵌套的)元組.

要將函數(shù)操作應(yīng)用于 SparseTensor 的非零元素,建議使用以下方法之一.首先,如果函數(shù)可以表示為 TensorFlow ops,請使用:

result = SparseTensor(input.indices, fn(input.values), input.dense_shape)

但是,如果該函數(shù)不能作為 TensorFlow op 表示,則使用:

result = SparseTensor(
  input.indices, map_fn(fn, input.values), input.dense_shape)

參數(shù):

  • fn:可調(diào)用的執(zhí)行.它接受一個參數(shù),它將具有與之相同的(可能嵌套的)結(jié)構(gòu) elems.其輸出必須具有與 dtype 相同的結(jié)構(gòu)(如果提供了),否則它必須具有與elems 相同的結(jié)構(gòu).
  • elems:張量或(可能是嵌套的)張量序列,其中的每一個都將沿著它們的第一維度進(jìn)行解壓.生成的切片的嵌套序列將應(yīng)用于 fn.
  • dtype:(可選)fn 的輸出類型.如果 fn 返回與 elems 結(jié)構(gòu)不同的張量結(jié)構(gòu),則 dtype 不是可選的,并且必須具有與 fn 的輸出相同的結(jié)構(gòu).
  • parallel_iterations:(可選)允許并行運行的迭代次數(shù).
  • back_prop:(可選)True 允許支持反向傳播.
  • swap_memory:(可選)True 可實現(xiàn) GPU-CPU 內(nèi)存交換.
  • infer_shape:(可選)False 禁用對一致輸出形狀的測試.
  • name:(可選)返回的張量的名稱前綴.

返回值:

該函數(shù)返回張量或(可能是嵌套的)張量序列.每個張量都將 fn 的結(jié)果應(yīng)用到從第一個維度的 elems,從第一個到最后一個.

可能發(fā)生的異常:

  • TypeError:如果 fn 不是可調(diào)用或 fn 的輸出的結(jié)構(gòu)和 dtype 不匹配,或者 elems 是 SparseTensor.
  • ValueError:如果 fn 的輸出長度 和 dtype 不匹配.

例子:

elems = np.array([1, 2, 3, 4, 5, 6])
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]

elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate == [-1, 2, -3]

elems = np.array([1, 2, 3])
alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號