TensorFlow vs. PyTorch vs. JAX: Deep-Learning-Frameworks im Vergleich

TensorFlow, PyTorch oder JAX: Welches Deep-Learning-Framework für Sie der Schlüssel zum Erfolg ist, hängt von verschiedenen Faktoren ab.


Foto: Evannovostro – shutterstock.com

Deep Learning verändert unser Leben auf vielfältige Weise: Ob Siri oder Alexa, die unseren Sprachbefehlen folgen, Echtzeit-Übersetzer auf unseren Smartphones oder die Computer-Vision-Technologie, die intelligente Roboter und autonomes Fahren ermöglicht. All diesen Deep-Learning-Anwendungsfällen ist gemein, dass sie auf einem der drei führenden Frameworks aufbauen:

Dieser Artikel wirft einen vergleichenden Blick auf die drei genannten Frameworks und gibt Ihnen Anhaltspunkte darüber, welche Stärken und Schwächen die Deep-Learning-Rahmenwerke mit sich bringen.

“Niemand wurde jemals gefeuert, weil er IBM gekauft hat” lautete das Mantra der IT-Branche in den 1970er und 1980er Jahren. Das könnte man auch mit Bezug auf TensorFlow umdichten – in den 2010er Jahren das Maß der Dinge in Sachen Deep Learning. Bekanntermaßen geriet IBM allerdings in den 1990er Jahren in schwierigeres Fahrwasser. Und TensorFlow? Ist das Framework knapp sieben Jahre nach seiner Veröffentlichung 2015 immer noch wettbewerbsfähig?

Um es kurz zu machen: Ja. Schließlich hat sich TensorFlow seit seinem Debüt auch weiterentwickelt: Bei TensorFlow 1.x ging es im Wesentlichen darum, statische Graphen auf eine sehr “un-python-eske” Weise zu erstellen. Seit TensorFlow 2.x steht nun auch der “Eager”-Modus zur Verfügung, um Modelle zu erstellen und direkt hinsichtlich ihrer Operationen auszuwerten – dadurch fühlt sich die Arbeit mit TensorFlow viel mehr nach PyTorch an.

Auf hoher Ebene bietet TensorFlow Keras für eine vereinfachte Entwicklung, auf niedriger Ebene den optimierenden XLA Compiler (Accelerated Linear Algebra) für Geschwindigkeit. Dieser wirkt in Sachen GPU-Leistungssteigerung Wunder und ist die primäre Methode, um sich TPUs (Tensor Processing Units) von Google zunutze zu machen. Diese bieten eine nicht dagewesene Performance, wenn es darum geht, Modelle in großem Maßstab zu trainieren. Dann gibt es noch all die Dinge, die TensorFlow schon seit Jahren gut macht:

  • TensorFlow Serving sorgt dafür, dass Modelle in einer wohldefinierten und wiederholbaren Weise auf einer ausgereiften Plattform bedient werden können.
  • TensorFlow.js und TensorFlow Lite ermöglichen es, Modellimplementierungen für das Web, für stromsparende Rechner wie Smartphones oder für ressourcenbeschränkte Geräte wie IoT-Devices neu auszurichten.
  • Da Google immer noch 100 Prozent seiner produktiven Einsätze mit TensorFlow stemmt, dürfen Sie zuversichtlich sein, dass TensorFlow auch Ihre Skalierung bewältigen kann.

Allerdings hat TensorFlow in den letzten Jahren unter einem gewissen “Energiemangel” gelitten, der schwer zu ignorieren ist. Das Upgrade von TensorFlow 1.x auf TensorFlow 2.x war – um es auf den Punkt zu bringen – brutal. So brutal, dass sich einige Unternehmen ob des zu erwartenden Aufwands, den sie hätten betreiben müssen, damit ihr Code auf der neuen Hauptversion richtig funktioniert, dafür entschieden, stattdessen auf PyTorch umzusteigen. Auch innerhalb der Forschungs-Community verlor TensorFlow zugunsten der Flexibilität von PyTorch an Boden.

Letztlich war auch die “Keras-Affäre” TensorFlow nicht zuträglich: Keras wurde vor zwei Jahren in die TensorFlow-Releases integriert – vor kurzem aber wieder in eine separate Bibliothek mit eigenem Release-Plan ausgegliedert. Die Abspaltung von Keras ist sicherlich nichts, was das tägliche Leben eines Entwicklers maßgeblich beeinflusst – als vertrauensbildende Maßnahme geht eine solche, öffentlichkeitswirksame Kehrtwende allerdings auch nicht durch. Abgesehen davon ist TensorFlow ein zuverlässiges Framework und beherbergt ein umfangreiches Deep-Learning-Ökosystem: Sie können Anwendungen und Modelle auf TensorFlow bauen, die sich nahezu beliebig skalieren lassen – und werden dabei in guter Gesellschaft sein. Dennoch ist TensorFlow im Jahr 2022 nicht mehr zwingend die erste Wahl.

Entsprechend dem gerade Gelesenen ist PyTorch nicht mehr der Emporkömmling, der sich an die Fersen von TensorFlow heftet. Stattdessen hat sich das Framework – in erster Linie im Forschungssektor aber zunehmend auch in Produktivumgebungen – zu einer treibenden Kraft in der Welt des Deep Learning entwickelt.

Da der “Eager”-Modus inzwischen sowohl für TensorFlow als auch für PyTorch zur Standardentwicklungsmethode geworden ist, scheint die automatische Differenzierung (autograd) von PyTorch, den “Krieg” gegen statische Graphen gewonnen zu haben. Im Gegensatz zu TensorFlow hat PyTorch seit der Abschaffung der Variable API in Version 0.4 keine größeren Brüche im Kerncode erfahren. Zuvor war Variable erforderlich, um Autograd mit Tensoren zu verwenden – jetzt ist alles ein Tensor. Das soll aber nicht heißen, dass es nicht hier und da ein paar Fehltritte gegeben hat: Wenn Sie zum Beispiel PyTorch verwendet haben, um ein Modell auf mehreren GPUs zu trainieren, sind Sie wahrscheinlich auf die Unterschiede zwischen DataParallel und dem aktuelleren DistributedDataParallel gestoßen. Sie sollten so gut wie immer DistributedDataParallel verwenden (DataParallel ist dennoch nicht wirklich veraltet).

  • PyTorch musste in Sachen XLA/TPU-Unterstützung lange hinter TensorFlow und JAX zurückstehen – das hat sich 2022 geändert: PyTorch kann jetzt auf TPU-VMs sowie TPU-Node-Support zurückgreifen.
  • Kombiniert wird das Ganze mit einer simplen Deployment-Option via Kommandozeile.
  • Wenn Sie sich nicht mit Code-Fitzeleien befassen wollen, stehen Ihnen übergeordnete Ergänzungen wie PyTorch Lightning zur Verfügung, die es Ihnen ermöglichen, sich auf Ihre eigentliche Arbeit zu konzentrieren, statt Trainingsschleifen neu zu schreiben.

Die Kehrseite der Medaille: Die Arbeit an PyTorch Mobile geht zwar weiter, es ist aber immer noch weit weniger ausgereift als TensorFlow Lite. In Bezug auf die Produktion bietet PyTorch Integrationen mit Framework-agnostischen Plattformen wie Kubeflow, während das TorchServe-Projekt Details der Bereitstellung wie Skalierung, Metriken und Batch-Inferenz händelt. Das bringt Ihnen alle MLOps-Vorzüge in einem kleinen Paket, das von den PyTorch-Entwicklern selbst gepflegt wird. Jeder, der sagt, PyTorch wäre nicht skalierbar, lügt. Meta setzt PyTorch seit Jahren in der Produktion ein. Dennoch gibt es Argumente dafür, dass PyTorch für sehr, sehr große Trainingsläufe, die eine Vielzahl von GPUs oder TPUs erfordern, nicht ganz so gut geeignet ist wie JAX.

Die Popularität von PyTorch hing in den letzten Jahren auch mit dem Erfolg der Transformers-Bibliothek von Hugging Face zusammen. Ja, Transformers unterstützt jetzt auch TensorFlow und JAX, aber es begann als PyTorch-Projekt und bleibt eng mit dem Framework verbunden. Mit dem Aufstieg der Transformers-Architektur, der Flexibilität von PyTorch für die Forschung und der Möglichkeit, so viele neue Modelle innerhalb weniger Tage oder Stunden nach der Veröffentlichung über den Modell-Hub von Hugging Face zu beziehen, ist es kein Wunder, dass PyTorch heutzutage überall Anklang findet.

Wenn Sie auf TensorFlow verzichten können, hat Google vielleicht noch etwas anderes für Sie: JAX ist ein Deep-Learning-Framework, das von Google entwickelt, gepflegt und verwendet wird – ist aber kein offizielles Google-Produkt. Ein Blick auf die Veröffentlichungen von Google/DeepMind aus dem letzten Jahr verdeutlicht jedoch, dass ein Großteil der Google-Forschung auf JAX übergegangen ist.

JAX kann man sich ganz einfach wie folgt vorstellen: Eine GPU/TPU-beschleunigte Version von NumPy, die auf magische Weise eine Python-Funktion vektorisieren und alle Ableitungsberechnungen für diese Funktionen durchführen kann. Schließlich verfügt JAX auch über eine JIT-Komponente (Just-In-Time), die Ihren Code für den XLA-Compiler optimiert, was zu erheblichen Leistungssteigerungen im Vergleich zu TensorFlow und PyTorch führt. Code lässt sich unter Umständen vier- bis fünfmal so schnell ausführen, indem er einfach in JAX reimplementiert wird – ohne wirkliche Optimierung. Da JAX auf NumPy-Ebene arbeitet, wird JAX-Code auf einer viel niedrigeren Ebene geschrieben als TensorFlow/Keras und, ja, auch PyTorch. Glücklicherweise gibt es ein kleines, aber wachsendes Ökosystem von umgebenden Projekten, die zusätzliche Bits hinzufügen:

Wenn Sie erst einmal mit einem Programm wie Flax gearbeitet haben, ist relativ einfach, neuronale Netze zu erstellen. Seien Sie sich jedoch darüber im Klaren, dass JAX auch ein paar Ecken und Kanten mitbringt: Zum Beispiel geht es mit Zufallszahlen anders um als viele andere Frameworks. Wenn Sie tief in einem Forschungsprojekt mit großen Modellen stecken, deren Training enorme Ressourcen erfordert, ist eine Konvertierung auf JAX eine Überlegung wert. Die Fortschritte, die das Framework in Bereichen wie deterministischem Training bietet, könnten den Wechsel schon alleine wert sein.

Welches Deep-Learning-Framework sollten Sie also verwenden? Diese Frage allgemeingültig zu beantworten, ist nicht möglich. Alles hängt von der Art des Problems ab, das Sie lösen wollen und von der Größenordnung, in der Sie Ihre Modelle zum Einsatz bringen wollen. Auch die Computing-Plattformen, auf die Sie dabei abzielen, spielen eine Rolle.

  • Wenn Sie im Text- und Bildbereich arbeiten und kleine oder mittlere Forschungsarbeiten mit dem Ziel durchführen, die Modelle in der Produktion einzusetzen, ist PyTorch dafür im Moment wahrscheinlich die beste Wahl.
  • Wollen Sie allerdings das letzte Quäntchen Leistung aus rechenschwachen Devices herauspressen, ist TensorFlow empfehlenswert.
  • Wenn Sie an Trainingsmodellen mit Dutzenden oder Hunderten von Milliarden von Parametern oder mehr arbeiten und diese hauptsächlich zu Forschungszwecken trainieren, dann sollten Sie JAX eine Chance geben.

(fm)

Dieser Beitrag basiert auf einem Artikel unserer US-Schwesterpublikation Infoworld.

https://www.computerwoche.de/a/deep-learning-frameworks-im-vergleich,3612680

Leave a Reply