- Published on
Jax Nedir?
- Authors
- Name
- İkbal Ünal
- Github
- @ikbalunal
JAX
- JAX yüksek performanslı makine öğrenimi ve yüksek performanslı sayısal hesaplama için açık kaynaklı bir çerçevedir. Google tarafından geliştirildi.
- Ne olmasını istediğinizi tanımlarsınız fonksiyonel olarak ve onun hızlı çalışmasını sağlamak için JAX kullanılır.
- Diğer çerçevelerle karşılaştırıldığında (pytorch,tensorflow) çoklu gpu paralelliği için çok iyi bir dahili desteğe sahiptir.
Python ve JIT Derleme Python normalde yorumlanan bir dildir. Python kodu, önce bytecode'a dönüştürülür ve ardından Python Sanal Makinesi (PVM) tarafından yorumlanarak çalıştırılır. JIT (Just-In-Time) derleme, bu süreci hızlandırmak için kullanılır. JIT derleyici, kodu çalıştırılmadan hemen önce makine koduna derler ve bu sayede kodun daha hızlı çalışmasını sağlar.
Normalde, bir fonksiyon her çağrıldığında Python tarafından yorumlanır ve çalıştırılır.
- Python'un Çalışma Şekli Python, hem derleme hem de yorumlama aşamalarını içerir:
- Bytecode Derleme: Python kodu önce bytecode denilen ara koda derlenir. Yorumlama: Bu bytecode, Python Sanal Makinesi (PVM) tarafından yorumlanarak çalıştırılır.
- JIT (Just-In-Time) Derleme JIT derleme, yorumlanan dillerde performansı artırmak için kullanılır:
- Anında Derleme: Kod, çalıştırılmadan hemen önce derlenir. Performans Artışı: Bu, kodun daha hızlı çalışmasını sağlar çünkü tekrar tekrar yorumlanmasına gerek kalmaz. Örnek: JAX, Java (JVM).
JAX'te JIT Derleme Nasıl Çalışır? Fonksiyon Tanımlama: İlk olarak, normal bir Python fonksiyonu tanımlarsınız. Bu fonksiyon, JAX tarafından derlenebilir ve optimize edilebilir. JIT ile Derleme: Bu fonksiyonu JIT derleme ile işaretlersiniz. JAX, bu işaretlemeyi gördüğünde fonksiyonu makine koduna derler. Hızlı Çalışma: Derlenmiş fonksiyon, normal fonksiyondan çok daha hızlı çalışır, çünkü tekrar tekrar yorumlanması gerekmez.
Derleme ve Yorumlama Arasındaki Farklar
Zamanlama:
Derleme: Program çalıştırılmadan önce derlenir. Yorumlama: Program çalıştırılırken yorumlanır. Çıktı:
Derleme: Çalıştırılabilir bir dosya (makine kodu) oluşturur. Yorumlama: Çalıştırılabilir bir dosya oluşturmaz; kod satır satır çalıştırılır. Performans:
Derleme: Derlenmiş kod daha hızlı çalışır çünkü doğrudan makine kodudur. Yorumlama: Yorumlanmış kod genellikle daha yavaştır çünkü her satır kod çalıştırılmadan önce yorumlanmalıdır. Hata Ayıklama:
Derleme: Hatalar derleme aşamasında tespit edilir, bu da bazen hata ayıklamayı zorlaştırabilir. Yorumlama: Hatalar anında görülebilir, bu da daha hızlı hata ayıklama sağlar. Esneklik:
Derleme: Derlenmiş kod platforma bağımlıdır. Yorumlama: Yorumlayıcı olan her platformda çalışabilir.
JAX'in nasıl çalıştığı ve rastgele sayı üretme konusundaki farklılıklar:
- Pure Functions (Saf Fonksiyonlar): JAX'teki temel bir kavram, fonksiyonların "saf" olmasıdır. Bu, aynı girdi verildiğinde her zaman aynı çıktıyı üreteceklerini ve işlev içinde herhangi bir global durumu değiştirmeyeceklerini belirtir.
- Random Number Generation (Rastgele Sayı Üretimi): JAX'te, rastgele sayı üretimi global bir durumu değiştirdiği için saf fonksiyonlarla çelişir. Bu nedenle, rastgelelik durumunu işlevler arasında taşımak ve global durumu değiştirmek yerine, rastgeleliği durumsuz bir şekilde ele almak gerekir.
- PRNG Key: Bu durumsuz rastgelelik yaklaşımı için JAX, bir PRNG anahtarı (PRNGKey) kullanır. Bu anahtar, rastgelelik durumunu temsil eder ve işlevler arasında geçirilerek rastgele sayı üretimi sağlanır.
- Reproducible Randomness (Tekrarlanabilir Rastgelelik): JAX, rastgeleliği varsayılan olarak sabit tohumlarla ele alır. Bu, kodunuzu tekrar çalıştırdığınızda aynı rastgele sayı dizisini alacağınız anlamına gelir, bu da deneylerinizi tekrar edilebilir hale getirir.
- Örnek Kod Parçası: Verilen bir tohum (seed) ile başlayarak bir anahtar oluşturulur. Bu anahtar, daha sonra alt anahtarlar oluşturmak için bölünebilir. Bu alt anahtarlar, rastgele sayı üretmek için kullanılır.
JAX'in bu yaklaşımı, özellikle makine öğrenimi ve bilimsel hesaplama gibi alanlarda, rastgelelikle ilgili işlemleri daha tutarlı ve tekrar edilebilir hale getirir. Bu, Numpy'den farklılık gösterir çünkü Numpy, global bir durumu değiştirerek rastgelelik sağlar, bu da fonksiyonların saf olmamasına yol açabilir.
jax.random.split
fonksiyonu, bir PRNG'den (Pseudo-random number generator) çıkan rastgele sayıları birden fazla parçaya bölmek için kullanılır. Bu işlem, genellikle aynı rastgele sayıları birden fazla bağımsız işlemde kullanmak istediğinizde veya veriyi rastgele ikiye ayırmanız gerektiğinde kullanılır.
Örneğin, eğitim ve doğrulama veri kümesi oluştururken veriyi rastgele ikiye bölmek istediğinizi düşünelim. Bu durumda, aynı rastgele sırayı kullanarak iki farklı veri kümesi oluşturmak isteyebilirsiniz. jax.random.split
fonksiyonunu kullanarak aynı rastgele sayı dizisini iki farklı parçaya bölebilir ve bu parçaları kullanarak iki farklı veri kümesi oluşturabilirsiniz.
Kısacası, jax.random.split
fonksiyonu, verilen PRNG'den gelen rastgele sayıları belirtilen sayıda parçaya böler. Bu işlem, aynı rastgele sayı dizisini farklı amaçlar için kullanmanızı sağlar.
Örneğin İki parçaya bölmek için jax.random.split
fonksiyonunu kullanabilirsiniz. İşte basit bir örnek:
import jax
import jax.numpy as jnp
from jax import random
# Anahtar oluşturma
key = random.PRNGKey(0)
# Rastgele sayıları oluşturma
rand_nums = random.normal(key, shape=(10,))
# Rastgele sayıları iki parçaya ayırma
rand_nums_1, rand_nums_2 = random.split(key, 2)
print("İlk parça:", rand_nums_1)
print("İkinci parça:", rand_nums_2)
Bu örnekte, random.split(key, 2)
ile rand_nums
dizisini iki parçaya böldük. key
değişkeni, rastgele sayıları üretmek için kullanılan anahtarı temsil eder. rand_nums_1
ve rand_nums_2
değişkenleri, rand_nums
dizisinin iki parçaya bölünmüş halleridir. Bu şekilde, rand_nums_1
ve rand_nums_2
değişkenleri arasında aynı rastgele sayı dizisinden üretilmiş rastgele sayılar bulunur.
Pekala, 100 parçaya bölmek istediğinizde 100 değişken tanımlamanıza gerek yok. Python'un yıldızlı ifadesini (*
) kullanarak istediğiniz sayıda değişkene atama yapabilirsiniz. Örneğin:
import jax
import jax.numpy as jnp
from jax import random
# Anahtar oluşturma
key = random.PRNGKey(0)
# Rastgele sayıları oluşturma
rand_nums = random.normal(key, shape=(100,))
# Rastgele sayıları 100 parçaya ayırma
rand_num_slices = random.split(key, 100)
# İlk 5 parçayı yazdırma
for i in range(5):
print(f"Parça {i+1}:", rand_num_slices[i])
Bu örnekte, random.split(key, 100)
ile rand_nums
dizisini 100 parçaya böldük ve rand_num_slices
adlı bir listeye atadık. Daha sonra for
döngüsü kullanarak ilk 5 parçayı yazdırdık, ancak bu yöntemi kullanarak 100 parçayı tek tek işleyebilirsiniz.
Subkey oluşturma
Bir anahtar (key
) oluşturuyoruz ve bu anahtarı kullanarak 5 adet alt anahtar (subkeys
) oluşturuyoruz. Her bir alt anahtar, rastgele sayı dizisini parçalara bölerken kullanılır ve her bir alt anahtar farklı bir parçayı temsil eder.
import jax
from jax import random
# Bir anahtar oluştur
key = random.PRNGKey(0)
# Anahtarı 5 parçaya böl ve alt anahtarları oluştur
key, *subkeys = jax.random.split(key, 5)
# Alt anahtarları yazdır
for i, subkey in enumerate(subkeys):
print(f"Alt Anahtar {i+1}: {subkey}")
Bu kod parçasında, split
fonksiyonu ile anahtarı 5 parçaya böldük ve her bir alt anahtarı subkeys
listesine ekledik. Sonra, subkeys
listesindeki her bir alt anahtarı yazdırdık. Bu sayede, her bir alt anahtarın farklı bir rastgele sayı dizisini temsil ettiğini görebilirsiniz.
Bu anahtarlar ne işe yarıyor?
Bu anahtarlar, JAX'te kullanılan Pseudo-random number generator (PRNG) için başlangıç noktasını belirler. Rastgele sayı üretmek için kullanılan PRNG, bir başlangıç noktası veya "anahtar" gerektirir. Bu anahtar, PRNG'nin her çalıştırıldığında aynı rastgele sayı dizisini üretmesini sağlar.
jax.random.split
fonksiyonu, verilen anahtardan türetilen alt anahtarları kullanarak farklı rastgelelik örneklerini oluşturur. Bu, farklı parçalarda aynı rastgele sayı dizisini kullanmanızı sağlar. Örneğin, model eğitimi sırasında farklı parçalarda (örneğin, farklı işlemcilerde veya cihazlarda) aynı rastgele sayı dizisini kullanarak eğitimi paralelleştirebilirsiniz.
Anahtarlar, aynı zamanda rastgelelik örneğinin izlenmesini sağlar. Örneğin, eğitimde rastgelelik kullanılıyorsa ve sonuçları tekrar üretmek gerekiyorsa, aynı anahtarı kullanarak aynı rastgele sayı dizisini yeniden oluşturabilirsiniz. Bu, deneylerin tekrarlanabilirliğini sağlar.
1.anahtarı kullanarak bir sayı üretmek için aşağıdaki gibi bir işlem yapabilirsiniz:
import jax
from jax import random
# Bir anahtar oluştur
key = random.PRNGKey(0)
# Anahtarı 4 parçaya böl ve alt anahtarları oluştur
key, *subkeys = jax.random.split(key, 4)
# İkinci anahtarı kullanarak bir sayı üret
sayi = random.normal(subkeys[1], ())
print("Üretilen Sayı:", sayi)
Bu örnekte, subkeys[1]
ifadesiyle 2. alt anahtarı (subkeys
listesindeki ikinci eleman) kullanarak bir sayı üretiyoruz. Bu, farklı alt anahtarlarla farklı rastgele sayı dizileri oluşturarak rastgelelik örneklerini kontrol etmenize olanak tanır.
JAX ile NumPy arasındaki bazı farklar
- İşlem Sırasında Değişiklik Yapılamaz: JAX, NumPy'den farklı olarak in-place (yerinde) işlemleri desteklemez. Yani, bir JAX dizisindeki bir elemanı doğrudan değiştiremezsiniz. Bunun yerine, işlevsel olarak eşdeğer olan
at
yöntemini kullanmanız gerekir. - JAX Fonksiyonları Sadece NumPy veya JAX Dizilerini Kabul Eder: NumPy, Python listelerini de kabul ederken, JAX sadece NumPy veya JAX dizilerini kabul eder. Bu, performansta sessiz bir düşüşü önlemek için hata fırlatmayı tercih ettiği anlamına gelir.
- Dizinin Sınırları Dışındaki İndeksleme Hatası Vermiyor: JAX'te, bir dizinin sınırları dışında bir indekse erişmeye çalıştığınızda hata almak yerine, indeksi dizinin sınırları içindeki bir değere kısıtlar.
- Zamanlama Farklılıkları: JAX ve NumPy arasındaki işlemlerin zamanlaması farklı olabilir. Özellikle matris çarpımı (
@
operatörü) gibi işlemlerde JAX'in performansı, NumPy'e göre farklılık gösterebilir.
JAX'in performansının NumPy'e göre neden daha yavaş olabileceği ve block_until_ready
fonksiyonunun zamanlama için neden gerektiği hakkında açıklama:
- Zamanlama Farklılıkları: JAX ve NumPy arasındaki çarpma işlemindeki zamanlama farklılığı şu şekildedir. JAX'in performansının NumPy'e göre daha yavaş olmasının sebebi, JAX'in işlemleri asenkron olarak hedefleyerek gerçekleştirmesi ve Python'a kontrolü hemen geri vermesidir. Bu durum, işlem tamamlanmadan önce kontrolün Python'a dönmesiyle, gerçek hesaplama süresinden daha hızlı bir dönüş süresi elde edilmesine neden olabilir. Bu da doğru olmayan bir zamanlama sonucuna yol açabilir.
block_until_ready
fonksiyonu, işlemin tamamlanmasını beklemek için kullanılır ve zamanlamayı doğru yapabilmek için gereklidir. - JAX'in İşlevi: JAX'in asıl amacı grafikleri tanımlamak ve derleyiciye optimizasyon için izin vermektir. NumPy gibi işlemleri adım adım yaparak (eager olarak) çalıştırmak yerine, JAX ile grafikler tanımlayarak ve derleyiciye optimizasyonu bırakarak çalışmak daha etkilidir. Eğer JAX'i NumPy gibi adım adım kullanıyorsanız, optimizasyon için hiçbir alan bırakmamış olursunuz ve ek JAX üzerindeki işlemlerden kaynaklanan ek yavaşlık nedeniyle daha yavaş bir işlev elde edersiniz.
jax.jit’ e giriş
Bu metinde, JAX'in jax.jit
işlevini kullanarak nasıl hızlandırılabileceği anlatılmaktadır. Öncelikle, JAX'in standart işlem hızının düşük olmasının sebebinin, JAX'in işlemleri tek tek hedefleyerek gerçekleştirmesinden kaynaklandığı bilinmektedir. JAX'i etkili bir şekilde kullanmanın yolu, XLA'yı kullanarak birden çok işlemi -ideal olarak neredeyse tüm işlemleri- birlikte derlemektir.
jax.jit
işlevine veya @jax.jit
dekoratörüne derlemek istediğimiz işlevi geçirerek, derlenecek bölgeyi belirtebiliriz. Bu işlev derhal derlenmez, ancak ilk çağrıda derlenir - bu nedenle "just-in-time derleme" adı verilir.
Bu ilk çağrı sırasında, giriş dizilerinin şekilleri, bir hesaplama grafiğini izlemek için kullanılır. Python yorumcusuyla işlevi adım adım geçerek işlemleri tek tek yürütür ve ne olduğunu grafiğe kaydeder. Bu ara temsili XLA'ya verilebilir ve sonrasında derlenir, optimize edilir ve önbelleğe alınır. Bu önbellek, aynı işlevin aynı giriş dizi şekilleri ve veri türüyle çağrılması durumunda geri alınır; bu da izleme ve derleme sürecini atlayarak, yoğun şekilde optimize edilmiş, önceden derlenmiş ikili blob'u doğrudan çağırır.
Bu süreci görmek için bir örneğe bakalım:
def fn(W, b, x):
return x @ W + b
key, w_key, b_key, x_key = jax.random.split(key, 4)
W = jax.random.normal(w_key, (4, 2)),
b = jax.random.uniform(b_key, (2,))
x = jax.random.normal(x_key, (4,))
print("`fn` zamanı")
%timeit fn(W, b, x).block_until_ready()
print("`jax.jit(fn)` ilk çağrı zamanı")
jit_fn = jax.jit(fn)
%time jit_fn(W, b, x).block_until_ready()
print("`jit_fn` zamanı")
%timeit jit_fn(W, b, x).block_until_ready()
Out:
`fn` time
26.1 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
`jit_fn` first call (warmup) time
CPU times: user 35.8 ms, sys: 38 µs, total: 35.9 ms
Wall time: 36.3 ms
`jit_fn` time
7.62 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Bu kod parçası, fn
işlevinin performansını, jax.jit
ile derlenmiş jit_fn
işlevinin performansıyla karşılaştırır. İlk çağrının daha uzun sürmesi beklenir, bu nedenle bu çağrıyı zamanlama testinden dışlamak önemlidir. Ayrıca, bu basit örnekte bile, derlenmiş işlevin orijinal işlevden çok daha hızlı çalıştığını görebiliriz.