Tento blog bude příkladem metody použití metody „torch.argmax()“ v PyTorch.
Jak používat metodu „torch.argmax()“ v PyTorch?
Metoda „torch.argmax()“ bere jako vstup jakýkoli 1D nebo 2D tenzor a vrací tenzor, který obsahuje indexy/indexy maximálních hodnot podél daného rozměru.
Syntaxe metody „torch.argmax()“ je uvedena níže:
pochodeň. argmax ( < input_tensor > )
Chcete-li použít tuto metodu v PyTorch, projděte si následující příklady pro lepší pochopení:
Příklad 1: Použijte metodu „torch.argmax()“ s 1D tenzorem
V prvním příkladu vytvoříme 1D tenzor a použijeme s ním metodu „torch.argmax()“. Podívejme se na níže uvedený postup krok za krokem:
Krok 1: Import knihovny PyTorch
Nejprve importujte „ pochodeň ” pro použití metody “torch.argmax()”:
import pochodeňKrok 2: Vytvořte 1D tenzor
Poté vytvořte 1D tenzor a vytiskněte jeho prvky. Zde vytváříme následující „ Desítky1 'tensor ze seznamu pomocí ' pochodeň.tensor() Funkce ”:
Desítky1 = pochodeň. tenzor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )
tisk ( Desítky1 )
To vytvořilo 1D tenzor, jak je vidět níže:
Krok 3: Najděte indexy maximální hodnoty
Nyní použijte „ pochodeň.argmax() ” funkce pro nalezení indexu/indexů maximální hodnoty v “ Desítky1 “tensor:
T1_ind = pochodeň. argmax ( Desítky1 )Krok 4: Tisk indexu maximální hodnoty
Nakonec zobrazte index maximální hodnoty ve vstupním tenzoru:
tisk ( 'Indexy:' , T1_ind )Níže uvedený výstup zobrazuje index maximální hodnoty v „ Desítky1 ” tensor tj. 4. To znamená, že nejvyšší hodnota tenzoru je na 4. indexu, který je “ 9 “:
Příklad 2: Použijte metodu „torch.argmax()“ s 2D tenzorem
Ve druhém příkladu vytvoříme 2D tenzor a použijeme s ním metodu „torch.argmax()“. Postupujeme podle uvedených kroků:
Krok 1: Import knihovny PyTorch
Nejprve importujte „ pochodeň ” pro použití metody “torch.argmax()”:
import pochodeňKrok 2: Vytvořte 2D tenzor
Poté použijte „ pochodeň.tensor() ” pro vytvoření 2D tenzoru a tisk jeho prvků. Zde vytváříme následující „ Desítky2 '2D tenzor:
Desítky2 = pochodeň. tenzor ( [ [ 4 , 1 , - 7 ] , [ patnáct , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )tisk ( Desítky2 )
To vytvořilo 2D tenzor, jak je vidět níže:
Krok 3: Najděte indexy maximální hodnoty
Nyní najděte index maximální hodnoty v „ Desítky2 'tensor pomocí ' pochodeň.argmax() Funkce ”:
T2_ind = pochodeň. argmax ( Desítky2 )Krok 4: Tisk indexu maximální hodnoty
Nakonec zobrazte index maximální hodnoty ve vstupním tenzoru:
tisk ( 'Indexy:' , T2_ind )Podle níže uvedeného výstupu se index maximální hodnoty v „ Desítky2 “ tenzor je „3“. To znamená, že nejvyšší hodnota tenzoru je na 3. indexu, který je „ patnáct “:
Krok 5: Najděte indexy maximální hodnoty podél sloupců
Kromě toho mohou uživatelé také najít indexy/indexy maximálních hodnot podél každého sloupce tenzoru. Můžeme například použít „ šero=0 ” argument s funkcí “torch.argmax()”. Vyhledá indexy maximálních hodnot podél sloupců v „ Desítky2 ” tensor a poté vytiskne tyto indexy:
col_index = pochodeň. argmax ( desítky2 , ztlumit = 0 )tisk ( 'Indexy ve sloupcích:' , col_index )
Níže uvedený výstup ukazuje indexy maximálních hodnot podél každého sloupce tenzoru:
Krok 6: Najděte indexy maximální hodnoty podél řádků
Podobně mohou uživatelé také najít indexy/indexy maximálních hodnot podél každého řádku tenzoru. Použijte například „ šero=1 ” s funkcí “torch.argmax()” k nalezení indexů maximálních hodnot podél řádků v tenzoru “Tens2” a poté tyto indexy vytisknout:
row_index = pochodeň. argmax ( desítky2 , ztlumit = 1 )tisk ( 'Indexy v řádcích:' , row_index )
Indexy maximální hodnoty podél každého řádku tenzoru „Tens2“ lze vidět níže:
Účinně jsme vysvětlili metodu použití metody „torch.argmax()“ v PyTorch.
Poznámka : Zde můžete přistupovat k našemu Zápisníku Google Colab odkaz .
Závěr
Chcete-li použít metodu „torch.argmax()“ v PyTorch, nejprve importujte „ pochodeň “knihovna. Poté vytvořte požadovaný 1D nebo 2D tenzor a zobrazte jeho prvky. Dále použijte „ pochodeň.argmax() ” metoda k nalezení/výpočtu indexů/indexů maximálních hodnot v tenzoru. Kromě toho mohou uživatelé také najít indexy maximální hodnoty podél každého řádku nebo sloupce v tenzoru pomocí „ ztlumit “argument. Nakonec zobrazte index maximální hodnoty ve vstupním tenzoru. Tento blog je příkladem metody použití metody „torch.argmax()“ v PyTorch.