Juliaの多次元配列において各行の最大値のインデックスを取得する方法を紹介します。
これは、例えばバッチ処理をするときなどに使えると思います。
まずは、配列から最大の要素を持つindexを取得する方法から紹介します。
1. 1次元配列(ベクトル)の場合
1-1. argmax()を使う
1つ目の方法は、argmax()を用いる方法です。例えば以下のようなベクトルを考えます。
a = [1 3 6 4 2 5]
argmax(a)
とすると、結果は以下のようになります。
CartesianIndex(1, 3)
これは(1, 3)番目が最大値であることを示しています。CartesianIndexから値を取り出したい場合は、1番目に「1」、2番目に「3」という値が格納されていますので、最大値のindexを取り出すには
argmax(a)[2]
とすれば良いということになります。
1-2. findmax()を使う
2つ目に紹介する方法は、findmax()を使う方法です。使い方はargmaxと同じで、引数に配列を渡します。
a = [1 3 6 4 2 5]
findmax(a)
そうすると、以下のような結果になります。
(6, CartesianIndex(1, 3))
こちらは、1番目に配列の最大値である「6」、2番目にindexである「CartesianIndex(1, 3)」が格納されていますので、indexが欲しい場合は、
findmax(a)[2][2]
のようにする必要があります。
2. 多次元配列の場合
2-1. argmax()を使う
多次元配列から最大値のindexを取得したい場合も、1次元配列の時と同じ関数を用いることができます。
a = [1 3; 5 2; 6 4]
argmax(a)
結果は以下のようになります。
CartesianIndex(3, 1)
CartesianIndex(3, 1)が要素の場所を表しています。
2-2. findmax()を使う
argmax()ではなくfindmax()を用いる方法もあります。
a = [1 3; 5 2; 6 4]
findmax(a)
結果は以下のようになります。
(6, CartesianIndex(3, 1))
2番目の要素のCartesianIndex(3, 1)が要素の場所を表しています。
次の3節では、本題である各行の最大値のインデックスを取得する方法を紹介します。
3. 多次元配列の各行からそれぞれ最大のindexを取得する方法
Batch処理の場合のように、各行からそれぞれ最大の値を持つ要素のindexを取得したい場合は、mapslices()を使います。
例えば以下のような多次元配列を考えます。
a = [1 3 2 4; 5 2 6 4; 9 4 7 6]
それぞれの行から最大値のindexを取得すると、[4, 3, 1]という配列が得られると思います。Juliaではこれをmapslices()を使って実現します。mapslices()は引数に関数、配列、次元の3つを渡します。今回の例だと関数にはargmax、配列はa、次元は2になります。
mapslices(argmax, a, dims=2)
そうすると以下のように結果が出力されます。
3×1 Array{Int64,2}:
4
3
1
これで望みの結果が得られました。
4. まとめ
本記事では、配列から最大値のindexを取得する方法を紹介しました。特に、主に以下の関数の使い方について説明しました。
・argmax関数
・findmax関数
・mapslices関数
必ずしもこれらの方法だけではないと思うので、1つの実装方法として知っていただければと思います。