Every year, as part of their coursework, students from the University of Warsaw, Poland, work under the supervision of engineers from the NVIDIA Warsaw office on challenging problems in deep learning and accelerated computing. We present the work of three M.Sc. students—Alicja Ziarko, Paweł Pawlik, and Michał Siennicki—who significantly reduced the latency of TorToiSe, a multi-stage, diffusion-based, text-to-speech (TTS) model.
Alicja, Paweł, and Michał first learned about the recent advancements in speech synthesis and diffusion models. They chose the combination of classifier-free guidance and progressive distillation, which performs well in computer vision, and adapted it to speech synthesis, achieving a 5x reduction in diffusion latency without a regression in speech quality. Small perceptual speech tests confirmed the results. Notably, this approach does not require costly training of the original model from scratch.
Why speed up diffusion-based TTS?
Since the publication of WaveNet in 2016, neural networks have become the primary models for speech synthesis. In simple applications, such as synthesis for AI-based voice assistants, synthetic voices are almost indistinguishable from human speech. Such voices can be synthesized orders of magnitude faster than real time, for instance with the NVIDIA NeMo AI toolkit.
However, achieving high expressivity or imitating a voice based on a few seconds of recorded speech (few-shot) is still considered challenging.
Denoising Diffusion Probabilistic Models (DDPMs) emerged as a generative technique that enables the generation of images of great quality and expressivity based on input text. DDPMs can be readily applied to TTS because a frequency-based spectrogram, which graphically represents a speech signal, can be processed like an image.
For instance, in TorToiSe, which is a guided diffusion-based TTS model, a spectrogram is generated by combining the results of two diffusion models (Figure 1). The iterative diffusion process involves hundreds of steps to achieve a high-quality output, significantly increasing latency compared to state-of-the-art TTS methods, which severely limits its applications.
In Figure 1, the unconditional diffusion model iteratively refines the initial noise until a high-quality spectrogram is obtained. The second diffusion model is further conditioned on the text embeddings produced by the language model.
Methods for speeding up diffusion
Existing latency reduction techniques in diffusion-based TTS can be divided into training-free and training-based methods.
Training-free methods do not involve training the network used to generate images by reversing the diffusion process. Instead, they focus only on optimizing the multi-step diffusion process itself. The diffusion process can be seen as solving ordinary or stochastic differential equations (ODEs/SDEs), so one way to optimize it is to create a better solver, such as DDPM, DDIM, or DPM, which lowers the number of diffusion steps. Parallel sampling methods, such as those based on Picard iterations or Normalizing Flows, can parallelize the diffusion process to benefit from parallel computing on GPUs.
Training-based methods focus on optimizing the network used in the diffusion process. The network can be pruned, quantized, or sparsified, and then fine-tuned for higher accuracy. Alternatively, its neural architecture can be changed manually or automatically using NAS. Knowledge distillation techniques enable distilling the student network from the teacher network to reduce the number of steps in the diffusion process.
Distillation in diffusion-based TTS
Alicja, Paweł, and Michał decided to use the distillation approach based on its promising results in computer vision and its potential for an estimated 5x reduction in diffusion-model latency at inference. They managed to adapt progressive distillation to the diffusion part of a pretrained TorToiSe model, overcoming problems such as the lack of access to the original training data.
Their approach consists of two knowledge distillation phases:
Mimicking the guided diffusion model output
Training another student model
In the first knowledge distillation phase (Figure 2), the student model is trained to mimic the output of the guided diffusion model at each diffusion step. This phase reduces latency by half by combining the two diffusion models into one model.
To address the lack of access to the original training data, text embeddings from the language model are passed through the original teacher model to generate synthetic data used in distillation. The use of synthetic data also makes the distillation process more efficient because the entire TTS, guided diffusion pipeline does not have to be invoked at each distillation step.
In the second progressive distillation phase (Figure 3), the newly trained student model serves as a teacher to train another student model. In this technique, the student model is trained to mimic the teacher model while reducing the number of diffusion steps by a factor of two. This process is repeated many times to further reduce the number of steps, while each time, a new student serves as the teacher for the next round of distillation.
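As a rough illustration of this idea, the following Python sketch shows the core of one distillation round, in which a student denoiser is trained so that a single step matches two consecutive teacher steps. The time-step bookkeeping is simplified and all names (teacher, student, batches) are hypothetical; this is not the authors' training code.

import torch
import torch.nn.functional as F

def distillation_round(teacher, student, batches, optimizer):
    """One round of progressive distillation (schematic sketch)."""
    teacher.eval()
    for x_t, t, cond in batches:  # noisy spectrogram, timestep, text conditioning
        with torch.no_grad():
            x_mid = teacher(x_t, t, cond)         # first teacher step
            target = teacher(x_mid, t - 1, cond)  # second teacher step
        pred = student(x_t, t, cond)              # one student step spans both
        loss = F.mse_loss(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return student  # becomes the teacher for the next round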
A progressive distillation with seven iterations reduces the number of inference steps by a factor of 2^7 (128), from the 4,000 steps on which the model was trained down to 31 steps. This reduction results in a 5x speedup compared to the guided diffusion model, excluding the text embedding calculation cost.
The perceptual pairwise speech test shows that the distilled model (after the second phase) matches the quality of speech produced by the TTS model based on guided diffusion.
As an example, listen to the audio samples in Table 1 generated by the progressive distillation-based TTS model. The samples match the quality of the audio samples from the guided diffusion-based TTS model. If we simply reduced the number of diffusion steps to 31, instead of using progressive distillation, the quality of the generated speech would deteriorate significantly.
Table 1: Audio samples generated by the diffusion-based TTS model after progressive distillation (31 diffusion steps) compared to two baselines: the guided diffusion-based TTS model (2×80 diffusion steps) and the guided diffusion-based TTS model with a naive reduction to 31 diffusion steps. (The per-speaker audio samples are not reproduced here.)
Conclusion
Collaborating with academia and assisting young students in shaping their future in science and engineering is one of the core NVIDIA values. Alicja, Paweł, and Michał’s successful project exemplifies the partnership between the NVIDIA Warsaw, Poland, office and local universities.
The students managed to solve the challenging problem of speeding up the pretrained, diffusion-based, text-to-speech (TTS) model. They designed and implemented a knowledge distillation-based solution in the complex field of diffusion-based TTS, achieving a 5x speedup of the diffusion process. Most notably, their unique solution based on synthetic data generation is applicable to pretrained TTS models without access to the original training data.
We encourage you to explore NVIDIA Academic Programs and try out the NVIDIA NeMo Framework to create complete conversational AI (TTS, ASR, or NLP/LLM) solutions for the new era of generative AI.
This post covers best practices when working with shaders on NVIDIA GPUs. To get a high and consistent frame rate in your applications, see all Advanced API Performance tips.
Shaders play a critical role in graphics programming by enabling you to control various aspects of the rendering process. They run on the GPU and are responsible for manipulating vertices, pixels, and other data.
Constant buffer reads are most effective when threads in a warp access data uniformly. If you need divergent reads, use shader resource views (SRVs).
Typical cases where SRVs should be preferred over CBVs include the following:
Bones or skinning data
Lookup tables, like precomputed random numbers
To optimize buffers and group shared memory, use manual bit packing. When creating structures for packing data, consider the range of values a field can hold and choose the smallest datatype that can encompass this range.
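For example, a skinning-data element could be packed into a single uint as follows; the field names and bit widths here are illustrative, not a prescribed layout.

// Pack an 8-bit normalized weight, a 10-bit bone index, and 14 bits of flags into one uint.
uint PackSkinData(float weight, uint boneIndex, uint flags)
{
    uint w = (uint)(saturate(weight) * 255.0f + 0.5f);   // 8-bit UNORM
    return (w & 0xFFu) | ((boneIndex & 0x3FFu) << 8) | ((flags & 0x3FFFu) << 18);
}

// Reverse the packing when reading the buffer or group shared memory.
void UnpackSkinData(uint packedData, out float weight, out uint boneIndex, out uint flags)
{
    weight = (packedData & 0xFFu) / 255.0f;
    boneIndex = (packedData >> 8) & 0x3FFu;
    flags = (packedData >> 18) & 0x3FFFu;
}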
Optimize control flow by providing hints of the expected runtime behavior.
Make sure to enable the compile flag -all-resources-bound for DXC (or D3DCOMPILE_ALL_RESOURCES_BOUND in FXC) if possible. This enables a larger set of driver-side optimizations.
Consider using the [FLATTEN] and [BRANCH] keywords where appropriate.
A conditional branch may prevent the compiler from hoisting long-latency instructions, such as texture fetches.
The [FLATTEN] keyword hints that the compiler is free to hoist and start the load operations before the statement has been evaluated.
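The snippet below sketches both hints; the resources and helper values are illustrative, and explicit-LOD sampling is used so the fetches are legal inside divergent control flow.

Texture2D shadowMask : register(t0);
Texture2D detailMap : register(t1);
SamplerState linearSampler : register(s0);

float3 ShadeExample(float3 color, float shadowFactor, bool useDetailMap,
                    float2 uv, float detailScale)
{
    // [BRANCH] requests a real jump: the long-latency fetch is only issued
    // when the condition is true.
    [BRANCH]
    if (shadowFactor > 0.0f)
    {
        color *= shadowMask.SampleLevel(linearSampler, uv, 0).r;
    }

    // [FLATTEN] evaluates both sides, so the compiler may hoist and start
    // the detail-map fetch before the condition has been evaluated.
    [FLATTEN]
    if (useDetailMap)
    {
        color *= detailMap.SampleLevel(linearSampler, uv * detailScale, 0).rgb;
    }
    return color;
}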
Use Root Signature 1.1 to specify static data and descriptors to enable the driver to make the best possible shader optimizations.
Keep register use to a minimum. High register pressure can limit occupancy and may force the driver to spill registers to memory.
Prefer the use of gather instructions when loading single channel texture quads.
This will cut down the expected latency by almost 4x compared to the equivalent operation constructed from consecutive sample instructions.
Prefer structured buffers over raw buffers.
Structured buffers have stricter alignment requirements, which enables the driver to schedule more efficient load instructions.
Consider using numerical approximations or precomputed lookup tables of transcendental functions (exp, log, sin, cos, sqrt) in math-intensive shaders, for instance, physics simulations and denoisers.
To promote a fast path in the TEX unit, with up to 2x speedup, use point filtering in certain circumstances:
Low-resolution textures where point filtering is already an accurate representation.
Textures that are being accessed at their native resolution.
Not recommended
Don’t assume that half-precision floats are always faster than full precision, or the reverse.
On NVIDIA Ampere GPUs, it’s just as efficient to execute FP32 as FP16 instructions. The overhead of converting between precision formats may result in a net loss.
NVIDIA Turing GPUs may benefit from using FP16 math, as FP16 can be issued at twice the rate of FP32.
Compute shaders
Compute shaders are used for general-purpose computations, from data processing and simulations to machine learning.
Recommended
Consider using wave intrinsics over group shared memory when possible for communication across threads.
Starting from SM 6.0, HLSL supports warp-wide wave intrinsics natively without the need for vendor-specific HLSL extensions. Consider using vendor-specific APIs only when the expected functionality is missing. For more information, see Unlocking GPU Intrinsics in HLSL.
To increase atomic throughput, use wave instructions to coalesce atomic operations across a warp.
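A sketch of this pattern using SM 6.0 wave intrinsics is shown below; the buffer layout and the predicate are illustrative.

RWByteAddressBuffer counterBuffer : register(u0);

[numthreads(64, 1, 1)]
void CoalescedAppendCS(uint3 dtid : SV_DispatchThreadID)
{
    bool keep = (dtid.x & 3) == 0;                // placeholder predicate
    uint localCount = keep ? 1u : 0u;

    uint waveTotal = WaveActiveSum(localCount);   // survivors in this warp
    uint waveOffset = WavePrefixSum(localCount);  // exclusive slot within the warp

    uint waveBase = 0;
    if (WaveIsFirstLane())
        counterBuffer.InterlockedAdd(0, waveTotal, waveBase);  // one atomic per warp
    waveBase = WaveReadLaneFirst(waveBase);

    uint writeIndex = waveBase + waveOffset;      // unique global slot for this thread
    // Threads with keep == true would write their payload at writeIndex here.
}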
To maximize cache locality and to improve L1 and L2 hit rate, try thread group ID swizzling for full-screen compute passes.
A good starting point is to target a thread group size corresponding to between two or eight warps. For instance, thread group size 8x8x1 or 16x16x1 for full-screen passes. Make sure to profile your shader and tune the dimensions based on profiling results.
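For instance, an 8x8x1 thread group gives 64 threads, or two warps of 32, and looks like this in HLSL; the entry point name is illustrative.

[numthreads(8, 8, 1)]
void FullScreenCS(uint3 dtid : SV_DispatchThreadID)
{
    // Per-pixel work for the full-screen pass, indexed by dtid.xy.
}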
Not recommended
Do not make your thread group size difficult to scale per platform and GPU architecture.
Specialization constants can be used in Vulkan to set the dimensions at pipeline creation time whereas HLSL requires the thread group size to be known at shader compile time.
Don’t ignore thread group launch latency.
If your CS has early-out conditions that are expected to early out in most cases, it might be better to choose larger thread group dimensions and cut down on the total number of thread groups launched.
Pixel shaders
Pixel shaders, also known as fragment shaders, are used to calculate effects on a per-pixel basis.
Recommended
Prefer the use of depth bounds test or stencil and depth testing over manual depth tests in pixel shaders.
Depth and stencil tests may discard entire 16×16 raster tiles down to individual pixels. Make sure that Early-Z is enabled.
Be mindful of the use patterns that may force the driver to disable Early-Z testing:
Conditional z-writes such as clip and discard
As an alternative consider using null blend ops instead
Pixel shader depth write
Writing to UAV resources
Consider converting your full screen pass to a compute shader if there’s a large difference in latency between warps.
Not recommended
Don’t use raster order view (ROV) techniques pervasively.
Guaranteeing order doesn’t come for free.
Always compare with alternative approaches like advanced blending ops and atomics.
Vertex shaders
Vertex shaders are used to calculate effects on a per-vertex basis.
Recommended
Prefer the use of compressed vertex formats.
Prefer the use of SRVs for skinning data over CBVs. This is a typical case of divergent CBV reads.
Geometry, domain, and hull shaders
Geometry, domain, and hull shaders are used to control, evaluate, and generate geometry, enabling tessellation to create a dynamic generation of surfaces and objects.
Recommended
Replace the geometry, domain, and hull shaders with the mesh shading capabilities introduced in NVIDIA Turing.
Enable the fast geometry path with the following configuration:
Fixed topology: Neither an expansion nor a reduction in the number of vertices.
Fixed primitive type: The input primitive type is equal to the output primitive type.
Immutable per-vertex attributes: The application cannot change the vertex attributes and can only copy them from the input to the output.
Mutable per-primitive attributes: The application can compute a single value for the whole primitive, which then is passed to the fragment shader stage. For example, it can compute the area of the triangle.
Acknowledgments
Thanks to Ryan Prescott, Ana Mihut, Katherine Sun, and Ivan Fedorov.
Posted by Stephan Rasp, Research Scientist, and Carla Bromberg, Program Lead, Google Research
In 1950, weather forecasting started its digital revolution when researchers used the first programmable, general-purpose computer ENIAC to solve mathematical equations describing how weather evolves. In the more than 70 years since, continuous advancements in computing power and improvements to the model formulations have led to steady gains in weather forecast skill: a 7-day forecast today is about as accurate as a 5-day forecast in 2000 and a 3-day forecast in 1980. While improving forecast accuracy at the pace of approximately one day per decade may not seem like a big deal, every day improved is important in far reaching use cases, such as for logistics planning, disaster management, agriculture and energy production. This “quiet” revolution has been tremendously valuable to society, saving lives and providing economic value across many sectors.
Now we are seeing the start of yet another revolution in weather forecasting, this time fueled by advances in machine learning (ML). Rather than hard-coding approximations of the physical equations, the idea is to have algorithms learn how weather evolves from looking at large volumes of past weather data. Early attempts at doing so go back to 2018 but the pace picked up considerably in the last two years when several large ML models demonstrated weather forecasting skill comparable to the best physics-based models. Google’s MetNet [1, 2], for instance, demonstrated state-of-the-art capabilities for forecasting regional weather one day ahead. For global prediction, Google DeepMind created GraphCast, a graph neural network to make 10 day predictions at a horizontal resolution of 25 km, competitive with the best physics-based models in many skill metrics.
Apart from potentially providing more accurate forecasts, one key advantage of such ML methods is that, once trained, they can create forecasts in a matter of minutes on inexpensive hardware. In contrast, traditional weather forecasts require large super-computers that run for hours every day. Clearly, ML represents a tremendous opportunity for the weather forecasting community. This has also been recognized by leading weather forecasting centers, such as the European Centre for Medium-Range Weather Forecasts’ (ECMWF) machine learning roadmap or the National Oceanic and Atmospheric Administration’s (NOAA) artificial intelligence strategy.
To ensure that ML models are trusted and optimized for the right goal, forecast evaluation is crucial. Evaluating weather forecasts isn’t straightforward, however, because weather is an incredibly multi-faceted problem. Different end-users are interested in different properties of forecasts, for example, renewable energy producers care about wind speeds and solar radiation, while crisis response teams are concerned about the track of a potential cyclone or an impending heat wave. In other words, there is no single metric to determine what a “good” weather forecast is, and the evaluation has to reflect the multi-faceted nature of weather and its downstream applications. Furthermore, differences in the exact evaluation setup — e.g., which resolution and ground truth data is used — can make it difficult to compare models. Having a way to compare novel and established methods in a fair and reproducible manner is crucial to measure progress in the field.
To this end, we are announcing WeatherBench 2 (WB2), a benchmark for the next generation of data-driven, global weather models. WB2 is an update to the original benchmark published in 2020, which was based on initial, lower-resolution ML models. The goal of WB2 is to accelerate the progress of data-driven weather models by providing a trusted, reproducible framework for evaluating and comparing different methodologies. The official website contains scores from several state-of-the-art models (at the time of writing, these are Keisler (2022), an early graph neural network, Google DeepMind’s GraphCast and Huawei’s Pangu-Weather, a transformer-based ML model). In addition, forecasts from ECMWF’s high-resolution and ensemble forecasting systems are included, which represent some of the best traditional weather forecasting models.
Making evaluation easier
The key component of WB2 is an open-source evaluation framework that allows users to evaluate their forecasts in the same manner as other baselines. Weather forecast data at high-resolutions can be quite large, making even evaluation a computational challenge. For this reason, we built our evaluation code on Apache Beam, which allows users to split computations into smaller chunks and evaluate them in a distributed fashion, for example using DataFlow on Google Cloud. The code comes with a quick-start guide to help people get up to speed.
Additionally, we provide most of the ground-truth and baseline data on Google Cloud Storage in cloud-optimized Zarr format at different resolutions, for example, a comprehensive copy of the ERA5 dataset used to train most ML models. This is part of a larger Google effort to provide analysis-ready, cloud-optimized weather and climate datasets to the research community and beyond. Since downloading these data from the respective archives and converting them can be time-consuming and compute-intensive, we hope that this should considerably lower the entry barrier for the community.
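As a rough illustration of how such a cloud-optimized Zarr store can be consumed with xarray (the bucket path and variable name below are placeholders; see the WeatherBench 2 documentation for the actual dataset locations):

import xarray as xr

# Placeholder path: substitute the actual WeatherBench 2 / ERA5 Zarr store URL.
store = "gs://<weatherbench2-bucket>/era5.zarr"
era5 = xr.open_zarr(store)

# Inspect the available variables and select a field (name shown is illustrative).
print(list(era5.data_vars))
t2m = era5["2m_temperature"].sel(time="2020-01-03")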
Assessing forecast skill
Together with our collaborators from ECMWF, we defined a set of headline scores that best capture the quality of global weather forecasts. As the figure below shows, several of the ML-based forecasts have lower errors than the state-of-the-art physical models on deterministic metrics. This holds for a range of variables and regions, and underlines the competitiveness and promise of ML-based approaches.
This scorecard shows the skill of different models compared to ECMWF’s Integrated Forecasting System (IFS), one of the best physics-based weather forecasts, for several variables. IFS forecasts are evaluated against IFS analysis. All other models are evaluated against ERA5. The order of ML models reflects publication date.
Toward reliable probabilistic forecasts
However, a single forecast often isn’t enough. Weather is inherently chaotic because of the butterfly effect. For this reason, operational weather centers now run ~50 slightly perturbed realizations of their model, called an ensemble, to estimate the forecast probability distribution across various scenarios. This is important, for example, if one wants to know the likelihood of extreme weather.
Creating reliable probabilistic forecasts will be one of the next key challenges for global ML models. Regional ML models, such as Google’s MetNet, already estimate probabilities. To anticipate this next generation of global models, WB2 already provides probabilistic metrics and baselines, among them ECMWF’s IFS ensemble, to accelerate research in this direction.
As mentioned above, weather forecasting has many aspects, and while the headline metrics try to capture the most important aspects of forecast skill, they are by no means sufficient. One example is forecast realism. Currently, many ML forecast models tend to “hedge their bets” in the face of the intrinsic uncertainty of the atmosphere. In other words, they tend to predict smoothed out fields that give lower average error but do not represent a realistic, physically consistent state of the atmosphere. An example of this can be seen in the animation below. The two data-driven models, Pangu-Weather and GraphCast (bottom), predict the large-scale evolution of the atmosphere remarkably well. However, they also have less small-scale structure compared to the ground truth or the physical forecasting model IFS HRES (top). In WB2 we include a range of these case studies and also a spectral metric that quantifies such blurring.
Forecasts of a front passing through the continental United States initialized on January 3, 2020. Maps show temperature at a pressure level of 850 hPa (roughly equivalent to an altitude of 1.5km) and geopotential at a pressure level of 500 hPa (roughly 5.5 km) in contours. ERA5 is the corresponding ground-truth analysis, IFS HRES is ECMWF’s physics-based forecasting model.
Conclusion
WeatherBench 2 will continue to evolve alongside ML model development. The official website will be updated with the latest state-of-the-art models. (To submit a model, please follow these instructions). We also invite the community to provide feedback and suggestions for improvements through issues and pull requests on the WB2 GitHub page.
Designing evaluation well and targeting the right metrics is crucial in order to make sure ML weather models benefit society as quickly as possible. WeatherBench 2 as it is now is just the starting point. We plan to extend it in the future to address key issues for the future of ML-based weather forecasting. Specifically, we would like to add station observations and better precipitation datasets. Furthermore, we will explore the inclusion of nowcasting and subseasonal-to-seasonal predictions to the benchmark.
We hope that WeatherBench 2 can aid researchers and end-users as weather forecasting continues to evolve.
Acknowledgements
WeatherBench 2 is the result of collaboration across many different teams at Google and external collaborators at ECMWF. From ECMWF, we would like to thank Matthew Chantry, Zied Ben Bouallegue and Peter Dueben. From Google, we would like to thank the core contributors to the project: Stephan Rasp, Stephan Hoyer, Peter Battaglia, Alex Merose, Ian Langmore, Tyler Russell, Alvaro Sanchez, Antonio Lobato, Laurence Chiu, Rob Carver, Vivian Yang, Shreya Agrawal, Thomas Turnbull, Jason Hickey, Carla Bromberg, Jared Sisk, Luke Barrington, Aaron Bell, and Fei Sha. We also would like to thank Kunal Shah, Rahul Mahrsee, Aniket Rawat, and Satish Kumar. Thanks to John Anderson for sponsoring WeatherBench 2. Furthermore, we would like to thank Kaifeng Bi from the Pangu-Weather team and Ryan Keisler for their help in adding their models to WeatherBench 2.
NVIDIA Jetson Orin is the best-in-class embedded platform for AI workloads. One of the key components of the Orin platform is the second-generation Deep Learning Accelerator (DLA), the dedicated deep learning inference engine that offers one-third of the AI compute on the AGX Orin platforms.
This post is a deep technical dive into how embedded developers working with Orin platforms can deploy deep neural networks (DNNs) using YOLOv5 as a reference. To learn more about how DLA can help maximize the performance of your deep learning applications, see Maximizing Deep Learning Performance on NVIDIA Jetson Orin with DLA.
YOLOv5 is an object detection algorithm. Building on the success of v3 and v4, YOLOv5 aims to provide improved accuracy and speed in real-time object detection tasks. YOLOv5 has gained wide adoption due to its excellent trade-off between accuracy and speed, making it a popular choice among researchers and practitioners in the field of computer vision. Its open-source implementation enables developers to leverage pretrained models and customize them according to specific goals. Specifically, this post walks through the following steps:
Train a YOLOv5 model with Quantization-Aware Training (QAT) and export it for deployment on DLA.
Deploy the network and run inference using CUDA through TensorRT and cuDLA.
Execute on-target YOLOv5 accuracy validation and performance profiling.
Using this sample, we demonstrate how to achieve 37.3 mAP on the COCO dataset with DLA INT8 (official FP32 mAP is 37.4). We also show how to obtain over 400 FPS for YOLOv5 on a single NVIDIA Jetson Orin DLA. (A total of two DLA instances are available on Orin.)
QAT training and export for DLA
To balance the inference performance and accuracy of YOLOv5, it’s essential to apply Quantization-Aware Training (QAT) to the model. Because DLA does not support QAT through TensorRT at the time of writing, it’s necessary to convert the QAT model to a Post-Training Quantization (PTQ) model before inference. The steps are outlined in Figure 1.
QAT training workflow
Use the TensorRT pytorch-quantization toolkit to quantize YOLOv5. The first step is to add quantizer modules to the neural network graph. This toolkit provides a set of quantized layer modules for common DL operations. If a module is not among the provided quantized modules, you can create a custom quantization module for the right place in the model.
The second step is to calibrate the model, obtaining the scale values for each Quantization/Dequantization (Q/DQ) module. After the calibration is complete, select a training schedule and fine-tune the calibrated model using the COCO dataset.
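The sketch below shows one way this calibration flow looks with the pytorch-quantization toolkit. build_yolov5 and calib_loader are hypothetical placeholders for your model constructor and calibration DataLoader, and the subsequent fine-tuning loop on COCO is omitted.

import torch
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn

# Monkey-patch common torch layers (Conv, Linear, ...) with quantized versions
# before the YOLOv5 network is built.
quant_modules.initialize()

model = build_yolov5()   # hypothetical model constructor
model.cuda().eval()

# Enable calibration and disable quantization while statistics are collected.
for m in model.modules():
    if isinstance(m, quant_nn.TensorQuantizer):
        if m._calibrator is not None:
            m.disable_quant()
            m.enable_calib()
        else:
            m.disable()

with torch.no_grad():
    for images, _ in calib_loader:   # hypothetical calibration DataLoader
        model(images.cuda())

# Load the computed amax (scale) values, then switch back to quantized mode.
for m in model.modules():
    if isinstance(m, quant_nn.TensorQuantizer):
        if m._calibrator is not None:
            m.load_calib_amax()
            m.disable_calib()
            m.enable_quant()
        else:
            m.enable()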
Adding Q/DQ nodes
There are two options for adding Q/DQ nodes to your network:
Option 1: Place Q/DQ nodes, as recommended, in TensorRT Processing of Q/DQ Networks. This method follows TensorRT fusion strategy for Q/DQ layers. These TensorRT strategies are mostly tuned for GPU inference. To make this compatible with DLA, add additional Q/DQ nodes, which can be derived using the scales from their neighboring layers with the Q/DQ Translator.
Any missing scales would otherwise result in certain layers running in FP16. This may result in a slight decrease in mAP and possibly a large performance drop. The Orin DLA is optimized for INT8 convolutions, about 15x over FP16 dense performance (or 30x when comparing dense FP16 to INT8 sparse performance).
Option 2: Insert Q/DQ nodes at every layer to make sure all tensors have INT8 scales. With this option, all layers’ scales can be obtained during model fine-tuning. However, this method may potentially disrupt TensorRT fusion strategy with Q/DQ layers when running inference on GPU and lead to higher latency on the GPU. For DLA, on the other hand, the rule of thumb with PTQ scales is, “The more available scales, the lower the latency.”
We verified the YOLOv5 model on the COCO 2017 validation dataset at a resolution of 672 x 672 pixels: Option 1 and Option 2 achieved mAP scores of 37.1 and 37.0, respectively.
Choose the best option based on your needs. If you already have an existing QAT workflow for GPU and would like to preserve it as much as possible, Option 1 is probably better. (You may need to extend Q/DQ Translator to infer more missing scales to achieve optimal DLA latency as well.)
On the other hand, if you are looking for a QAT training method that inserts Q/DQ nodes into all layers and is compatible with DLA, Option 2 may be the most promising choice.
Q/DQ Translator workflow
The purpose of the Q/DQ Translator is to translate an ONNX graph trained with QAT into PTQ tensor scales and an ONNX model without Q/DQ nodes.
For this YOLOv5 model, extract quantization scales from Q/DQ nodes in the QAT model. Use the information of neighboring layers to infer the input/output scales of other layers such as Sigmoid and Mul in YOLOv5’s SiLU or for Concat nodes. After scales are extracted, export the ONNX model without Q/DQ nodes and the (PTQ) calibration cache file such that TensorRT can use them to build a DLA engine.
Deploying network to DLA for inference
The next step is to deploy the network and run inference using CUDA through TensorRT and cuDLA.
Loadable build with TensorRT
Use TensorRT to build the DLA loadable. This provides an easy-to-use interface for DLA loadable building and seamless integration with GPU if needed. For more information about TensorRT-DLA, see Working with DLA in the TensorRT Developer Guide.
trtexec is a convenient tool provided by TensorRT for building engines and benchmarking performance. Note that a DLA loadable is the result of successful DLA compilation through the DLA Compiler, and that TensorRT can package DLA loadables inside of serialized engines.
First, prepare the ONNX model and the calibration cache generated in the previous section. The DLA loadable can be built with a single command. Pass the --safe option and the entire model can run on DLA. This directly saves the compilation result as a serialized DLA loadable (without a TensorRT engine wrapping around it). For more details about this step, see the NVIDIA Deep Learning TensorRT Documentation.
Note that the input format dla_hwc4 is highly recommended from a performance point of view, if your model input qualifies. The input must have at most four input channels and be consumed by a convolution. In INT8, DLA can benefit from a specific hardware and software optimization that is not available if you use --inputIOFormats=int8:chw32 instead, for example.
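A build command along these lines covers this step; the file names are placeholders and the exact flags should be adapted to your model and TensorRT version.

trtexec --onnx=yolov5_noqdq.onnx --useDLACore=0 --int8 --fp16 \
    --calib=yolov5_ptq_scales.cache --inputIOFormats=int8:dla_hwc4 \
    --safe --saveEngine=yolov5.loadable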
Running inference using cuDLA
cuDLA is the CUDA runtime interface for DLA, an extension of the CUDA programming model that integrates DLA with CUDA. cuDLA enables you to submit DLA tasks using CUDA programming constructs. You can run inference using cuDLA either implicitly through TensorRT runtime or you can explicitly call the cuDLA APIs. This sample demonstrates the latter approach to explicitly call cuDLA APIs to run inference in hybrid mode and standalone mode.
cuDLA hybrid mode and standalone mode mainly differ in synchronization. In hybrid mode, DLA tasks are submitted to a CUDA stream, so synchronization can be done seamlessly with other CUDA tasks.
In standalone mode, the cudlaTask structure has a provision to specify wait and signal events that cuDLA must wait on and signal respectively, as part of cudlaSubmitTask.
In short, using cuDLA hybrid mode can give quick integration with other CUDA tasks. Using cuDLA standalone mode can prevent the creation of CUDA context, and thus can save resources if the pipeline has no CUDA context.
The primary cuDLA APIs used in this YOLOv5 sample are detailed below.
cudaMalloc and cudlaMemRegister are called to first allocate memory on GPU, then let the CUDA pointer be registered with the DLA. (Used only for hybrid mode.)
cudlaSubmitTask is called to submit the inference task. In hybrid mode, users need to specify the CUDA stream to let cuDLA tasks run on it. In standalone mode, users need to specify the signal event and wait event to let cuDLA wait and signal when the corresponding fence expires.
On-target validation and profiling
It’s important to note the numerical differences between GPU and DLA. The underlying hardware is different, so the computations are not bit-wise identical. Because the network is trained on the GPU and then deployed to DLA on the target, it’s important to validate on the target. This specifically comes into play when it comes to quantization. It’s also important to compare against a reference baseline.
YOLOv5 DLA accuracy validation
We used the COCO dataset to validate. Figure 3 shows the inference pipeline architecture. First, load the image data and normalize it. Extra reformats on the inference inputs and outputs are needed because DLA only supports INT8/FP16.
After inference, decode the inference result and perform NMS (non-maximum suppression) to get the detection result. Finally, save the result and compute mAP.
In the case of YOLOv5, the feature maps of the last three convolution layers encode final detection information. When quantized to INT8, the quantization error of the bounding box coordinates becomes noticeable compared to FP16/FP32, thus affecting the final mAP.
Our experiment shows that running the last three convolution layers in FP16 improves the final mAP from 35.9 to 37.1. Orin DLA has a special hardware design highly optimized for INT8, so we observe a performance drop when these three convolutions run in FP16.
Table 1. Configurations exploring mixed precision for the last three convolution layers
Note that the mAP results are based on Option 1 described in the preceding section on adding Q/DQ nodes. You can apply the same principle to Option 2 as well.
YOLOv5 DLA performance
DLA offers one-third of AI compute on Orin AGX platforms, thanks to the two DLA cores. For a general baseline of Orin DLA performance, see Deep-Learning-Accelerator-SW on GitHub.
In the latest release, DLA 3.14.0 (DOS 6.0.8.0 and JetPack 6.0), several performance optimizations were added to the DLA compiler that specifically apply for INT8 CNN architecture-based models:
Native INT8 Sigmoid (previously ran in FP16 and had to be cast to and from INT8; also applies to Tanh)
INT8 SiLU fusion into a single DLA HW operation (instead of standalone Sigmoid plus standalone elementwise Mul)
Fusing the INT8 SiLU HW op with the previous INT8 Conv HW op (also applies to standalone Sigmoid or Tanh)
These improvements can provide a 6x speedup for YOLO architectures compared to prior releases. For instance, in the case of YOLOv5, the inference performance jumped from 13 ms to 2.4 ms in INT8 (with a few layers running in FP16), which is a 5.4x improvement. Further, you can use the cuDLA sample to profile your DNN layer-wise, identify bottlenecks, and modify your network to improve its performance.
Get started with DLA
This post explains how to run an entire object detection pipeline on Orin in the most efficient way using YOLOv5 on its dedicated Deep Learning Accelerator. Keep in mind that other SoC components such as the GPU are either idling or running at very small load. If you had a single camera producing inputs at 30 fps, one DLA instance would only be loaded at about 10%. So there is plenty of headroom for adding more bells and whistles to your application.
Ready to dive in? The YOLOv5 sample replicates the entire workflow discussed here. You can use it as a reference point for your own use case.
For beginners, the Jetson_dla_tutorial on GitHub demonstrates a basic DLA workflow to help you get started deploying a simple model to DLA.
Ray and path tracing algorithms construct light paths by starting at the camera or the light sources and intersecting rays with the scene geometry. As objects are hit, new secondary rays are generated on these surfaces to continue the paths.
In theory, these secondary rays will not yield an intersection with the same triangle again, as intersections at a distance of zero are excluded by the intersection algorithm. In practice, however, the finite floating-point precision used in the actual implementation often leads to false-positive results, known as self-intersections (Figure 2). This creates artifacts, such as shadow acne, where the triangle sometimes improperly shadows itself (Figure 1).
Self-intersection can be avoided by explicitly excluding the same primitive from intersection using its identifier. In DirectX Raytracing (DXR) this self-intersection check would be implemented in an any-hit shader. However, forcing an any-hit invocation for all triangle hits comes at a significant performance penalty. Furthermore, this method does not deal with false positives against adjacent (near) coplanar triangles.
The most widespread solutions to work around the issue use various heuristics to offset the ray along either the ray direction or the normal. These methods are, however, not robust enough to handle a variety of common production content and may even require manual parameter tweaking on a per-scene basis, particularly in scenes with heavily translated, scaled or sheared instanced geometry. For more information, see Ray Tracing Gems: High-Quality and Real-Time Rendering with DXR and Other APIs.
Alternatively, the sources of the numerical imprecision can be numerically bounded at runtime, giving robust error intervals on the intersection test. However, this comes with considerable performance overhead and requires source access to the underlying implementation of the ray/triangle intersection routine, which is not possible in a hardware-accelerated API like DXR.
This post describes a robust offsetting method for secondary rays spawned from triangles in DXR. The method is based on a thorough numerical analysis of the sources of the numerical imprecision. It involves computing spawn points for secondary rays, safe from self-intersections. The method does not require modification of the traversal and ray/triangle intersection routines and can thus be used with closed-source and hardware-accelerated ray tracing APIs like DXR. Finally, the method does not rely on self-intersection rejection using an any-hit shader and has a fixed overhead per shading point.
Method overview
The spawn point of a secondary ray coincides with the hit point on a triangle of an incoming ray. The goal is to compute a spawn point as close as possible to the hit point in the triangle plane, while still avoiding self-intersections. Too close to the triangle may result in self-intersection artifacts, but too far away may push the spawn point past nearby geometry, causing light leaking artifacts.
Figure 2 shows the sources of numerical error for secondary rays. In the user shader, the object-space hit point is reconstructed and transformed into world-space. During DXR ray traversal, the world-space ray is transformed back into object-space and intersected against triangles.
Each of these operations accumulates numerical errors, possibly resulting in self-intersections. This method computes a minimal uncertainty interval centered around the intended ray origin (red dot in Figure 2) on the triangle at each operation. The approximate ray origin (black dot in Figure 2) lies within this uncertainty interval. The ray origin is offset along the triangle normal beyond the final uncertainty interval to prevent self-intersections.
Hit point
Start by reconstructing the hit point and the geometric triangle normal in object-space (Listing 1).
precise float3 edge1 = v1 - v0;
precise float3 edge2 = v2 - v0;
// interpolate triangle using barycentrics
// add in base vertex last to reduce object-space error
precise float3 objPosition = v0 + mad(barys.x, edge1, mul(barys.y, edge2));
float3 objNormal = cross(edge1, edge2);
The hit point is computed by interpolating the triangle vertices v0, v1, and v2 using the 2D barycentric hit coordinates barys. Although it is possible to compute the interpolated hit point using two fused multiply-add operations, adding the base vertex v0 last reduces the maximum rounding error on the base vertex, which in practice dominates the rounding error in this computation.
Use the precise keyword to force the compiler to perform the computations exactly as specified. Enforced precise computation of the normal and the error bounds is not required. The effects of rounding errors on these quantities are vanishingly small and can safely be ignored for self-intersection.
Next, the object-space position is transformed into world-space (Listing 2).
Instead of using the HLSL matrix mul intrinsic, write out the transformation. This ensures that the translational part of the transformation is added last. This again reduces the rounding error on the translation, which in practice tends to dominate the error in this computation.
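The original listing is not reproduced above; a sketch consistent with that description, using the DXR ObjectToWorld3x4 intrinsic, looks as follows. The essential point is the order of operations, with the translation column added last.

const float3x4 o2w = ObjectToWorld3x4();
// transform object-space position to world-space
// add the translation last to reduce world-space error
precise float3 wldPosition;
wldPosition.x = o2w._m03 +
    mad(o2w._m00, objPosition.x,
        mad(o2w._m01, objPosition.y,
            mul(o2w._m02, objPosition.z)));
wldPosition.y = o2w._m13 +
    mad(o2w._m10, objPosition.x,
        mad(o2w._m11, objPosition.y,
            mul(o2w._m12, objPosition.z)));
wldPosition.z = o2w._m23 +
    mad(o2w._m20, objPosition.x,
        mad(o2w._m21, objPosition.y,
            mul(o2w._m22, objPosition.z)));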
Finally, transform the object-space normal to world-space and normalize it (Listing 3).
const float3x4 w2o = WorldToObject3x4();
// transform normal to world-space using
// inverse transpose matrix
float3 wldNormal = mul(transpose((float3x3)w2o), objNormal);
// normalize world-space normal
const float wldScale = rsqrt(dot(wldNormal, wldNormal));
wldNormal = mul(wldScale, wldNormal);
// flip towards incoming ray
if(dot(WorldRayDirection(), wldNormal) > 0)
wldNormal = -wldNormal;
To support transformations with uneven scaling or shear, the normals are transformed using the inverse transpose transformation. There is no need to normalize the object-space normal before the transformation. It is necessary to normalize again in world-space anyway. Because the inverse length of the world normal is needed again later to appropriately scale the error bounds, normalize manually instead of using the HLSL normalize intrinsic.
Error bounds
With an approximate world-space position and triangle normal, continue by computing error bounds on the computed position, bounding the maximum finite precision rounding error. It is necessary to account for the rounding errors in the computations in Listings 1 and 2.
It is also necessary to account for rounding errors that may occur during traversal (Figure 2). During traversal, DXR will apply a world-to-object transformation and perform a ray-triangle intersection test. Both of these are performed in finite precision and thus introduce rounding errors.
Start by computing a combined object-space error bound, accounting both for the rounding errors in Listing 1 and rounding errors due to the DXR ray-triangle intersection test (Listing 4).
const float c0 = 5.9604644775390625E-8f;
const float c1 = 1.788139769587360206060111522674560546875E-7f;
// compute twice the maximum extent of the triangle
const float3 extent3 = abs(edge1) + abs(edge2) +
abs(abs(edge1) - abs(edge2));
const float extent = max(max(extent3.x, extent3.y), extent3.z);
// bound object-space error due to reconstruction and intersection
float3 objErr = mad(c0, abs(v0), mul(c1, extent));
Note that the error on the triangle intersection is bounded by the maximum triangle extent along the three dimensions. A rigorous proof for this bound goes beyond the scope of this post. To provide an intuitive justification, common ray-triangle intersection algorithms reorient the triangle into ’ray space’ (by subtracting the ray origin) before performing the intersection test. In the context of self-intersection, the ray origin lies on the triangle. Thus, the magnitude of the remaining triangle vertices in this ray space is bounded by the extent of the triangle along each dimension.
Furthermore, these intersection algorithms project the triangle into a 2D plane. This projection causes errors along one dimension to bleed over into the other dimensions. Therefore, take the maximum extent along all dimensions, instead of treating the error along the dimensions independently. The exact bound on the ray-triangle intersection test will be hardware-specific. The constant c1 is tuned for NVIDIA RTX hardware, but may require some adjusting on different platforms.
Error bounds for custom intersection primitives depend on the implementation details of their Intersection shader. See Advanced Linear Algebra: Foundations to Frontiers for a thorough introduction to finite precision rounding error analysis.
Next, compute the world-space error bound due to the transformation of the hit point from object-space to world-space (Listing 5).
That leaves the rounding errors in the world-to-object transformation performed by DXR during ray traversal (Listing 6).
// bound object-space error due to world-to-object transform
objErr = mad(c2, mul(abs(w2o), float4(abs(wldPosition), 1)), objErr);
Like the ray-triangle intersection test, the rounding error in the world-to-object transformation depends on the hardware. The constant c2 is conservative and should suffice for the various ways of implementing the vector matrix multiplication.
The finite precision representation of the world-to-object transformation matrix and its inverse are not guaranteed to match exactly. In the analysis, the error in the representation can be attributed to one or the other. Because the object-to-world transformation is performed in user code, the errors are best attributed to the object-to-world transformation matrix, enabling tighter bounds.
Offset
The previous section explained how to compute bounds on the rounding errors for secondary ray construction and traversal. These bounds yield an interval around the approximate, finite precision ray origin. The intended, full-precision ‘true’ ray origin is guaranteed to lie somewhere in this interval.
The true triangle passes through the true ray origin, so the triangle also passes through this interval. Figure 3 shows how to offset the approximate origin along the triangle normal to guarantee it lies above the true triangle, thus preventing self-intersections.
The error bound ∆ is projected onto the normal n to obtain an offset δ along the normal
Rounding errors on the normal are of similar magnitude as rounding errors on the computation of the error bounds and offset themselves. These are vanishingly small and can in practice be ignored. Combine the object and world-space offsets into a single world-space offset along the world-space normal (Listing 7).
Use the already normalized world-space normal from Listing 3. The world-space part of the offset is simply the world-space error bound projected onto the (component-wise absolute) world-space normal. The object-space offset along the object-space normal must also be carried into world-space: offsetting the hit point in object-space and applying the object-to-world transformation shifts the world-space point by the transformed offset vector.
Note, however, that the transformed object-space offset is not necessarily parallel to the world-space normal. To obtain a single combined offset along the world-space normal, project the transformed object-space offset onto the world-space normal. Because the world-space normal is the normalized inverse-transpose transform of the object-space normal, this projection simplifies to the object-space offset (computed against the unnormalized object-space normal) scaled by wldScale from Listing 3, which can then be added to the world-space part of the offset.
Finally, use the computed offset to perturb the hit point along the triangle normal (Listing 8).
// offset along the normal on either side.
precise float3 wldFront = mad( wldOffset, wldNormal, wldPosition);
precise float3 wldBack = mad(-wldOffset, wldNormal, wldPosition);
This yields front and back spawn points safe from self-intersection. The derived error bounds (and thus offsets) neither depend on the incoming ray direction nor the outgoing secondary ray direction. It is therefore possible to reuse the same spawn points for all secondary rays originating from this hit point. All reflection rays should use the front spawn point while transmission rays should use the back spawn point.
Object-to-world and world-to-object transformations of the direction also cause rounding errors in the ray direction. At extreme grazing angles, these rounding errors may cause it to flip sides, orienting it back towards the triangle. The offsetting method in this post does not protect against such rounding errors. It is generally advised to filter out secondary rays at extreme angles.
Alternatively, similar error bounds can be derived on the ray direction transformations. Offsetting the ray direction along the triangle normal (as for the ray origin) can then guarantee its sidedness. However, as the reflectance distribution of common BRDF models tends towards zero at grazing angles, this problem can be safely ignored in many applications.
Object space
As seen in Listing 4, the offset grows linearly in the triangle extent and the magnitude of the triangle base vertex in object-space. For small triangles, the rounding error in the base vertex will dominate the object-space error (Figure 2). It is thus possible to reduce the object-space error by repositioning geometry in object-space, centering it around the object-space origin to minimize the distance to the origin. For geometry with extremely large triangles, such as ground planes, it may be worthwhile to tessellate the geometry and further reduce the rounding errors in the triangle extent.
Camera space
As seen in Listings 5 and 6, the magnitude of the offset grows linearly with the magnitude of the world-space position. The proportionality constant c2 is approximately 1 ulp, so instanced geometry at a distance from the scene origin in world-space has a maximum rounding error on the order of that distance times c2, or roughly 1 mm of offset for every 4 km of distance. The offset magnitudes also scale linearly with the triangle extent and object-space position.
For an example secondary ray in Figure 4 spawned on a leaf of 10 cm, in a tree of 20 m (object-space origin at the root) 1 km away from the world-space origin, the offset magnitudes due to the triangle extent, object-space position, and world-space position will be on the order of 45 nm, 4 µm, and 0.25 mm, respectively. In practice, rounding errors in the world-space position tend to dominate all rounding errors. This is particularly true for large scenes of relatively small objects.
Note that the error is proportional to the world-space distance to the scene origin, not the scene camera. Consequently, if the camera is far away from the scene origin, the offsets for rays spawned from nearby geometry may become prohibitively large, resulting in visual artifacts.
This problem can be reduced by translating the entire scene into camera space. All instances are repositioned so the camera origin coincides with the world-space origin. Consequently, the distance becomes the distance to the camera in this camera space and the offset magnitudes will be proportional to the distance to the camera. Rays spawned from geometry near the camera will enjoy relatively small offsets, reducing the likelihood of visual artifacts due to offsetting.
Connection rays
This discussion has so far focused on offsetting of the ray origin to prevent self-intersection at the origin. Ray and path tracing algorithms also trace rays to evaluate visibility between two points on different triangles, such as shadow rays connecting a shading point and a light source.
These rays may suffer from self-intersection on either end of the ray. It is necessary to offset both ends to avoid self-intersections. The offset for the endpoint is computed in a similar fashion as for the ray origin, but using the object-to-world and world-to-object transformation matrices, barycentric and triangle vertices of the endpoint and using the connection ray direction as the incoming ray direction.
Contrary to scattering rays, it is necessary to account for rounding errors in the world-to-object ray direction transform during traversal. Theoretically, it is also necessary to account for additional rounding error in the ray-triangle intersection test because the ray origin does not lie on the endpoint triangle. However, this additional error scales sublinearly with the world-to-object error, so for simplicity these errors are implicitly combined.
For the endpoint, the world-to-object transformation error computation in Listing 6 is replaced by (Listing 9).
// connection ray direction
precise float3 wldDir = wldEndPosition - wldOrigin;
// bound endpoint object-space error due to object-to-world transform
float4 absOriginDir = float4(abs(wldOrigin) + abs(wldDir), 1);
objEndErr = mad(c2, mul(abs(w2oEnd), absOriginDir), objEndErr);
Here, wldOrigin is the connection ray origin in world-space. In DXR, rays are defined using an origin and direction. Instead of offsetting the endpoint and recomputing the ray direction, apply the offset directly to the world-space direction. For endpoint offsetting, Listing 8 thus becomes Listing 10.
// offset ray direction along the endpoint normal towards the ray origin
wldDir = mad(wldEndOffset, wldEndNormal, wldDir);
// shorten the ray tmax by 1 ulp
const float tmax = 0.99999994039f;
Shorten the ray length by 1 ulp to account for rounding errors in the direction computation.
In practice, a simpler approach of using a cheap approximate offsetting heuristic in combination with identifier-based self-intersection rejection is often sufficient to avoid endpoint self-intersection. The approximate offsetting will avoid most endpoint self-intersections, with identifier-based hit rejection taking care of the remaining self-intersections.
For secondary scatter rays, avoid identifier based self-intersection rejection, as it requires invoking an any-hit shader for every intersection along the ray, adding significant performance overhead. However, for visibility rays, the additional performance overhead of endpoint identifier-based hit rejection is minimal.
For visibility rays using the RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH flag there will always be at most two additional reported hits: the rejected endpoint self-intersection and any occluder terminating traversal.
For visibility rays not using the RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH flag, self-intersections can be rejected in the closest-hit shader instead of the any-hit shader. If the visibility ray invokes the closest-hit shader for the endpoint triangle, no closer hit was found and thus the hit should simply be treated as a miss in the closest-hit shader.
Conclusion
The method presented in this post offers a robust and easy-to-use solution for self-intersections of secondary rays. The method applies a minimal conservative offset, resolving self-intersection artifacts while reducing light leaking artifacts. Moreover, the method has minimal runtime overhead and integrates easily in common shading pipelines. While this post describes an HLSL implementation for DXR, the approach translates easily to GLSL for Vulkan and CUDA for OptiX.
Graph neural networks (GNNs) have emerged as a powerful tool for a variety of machine learning tasks on graph-structured data. These tasks range from node classification and link prediction to graph classification. They also cover a wide range of applications such as social network analysis, drug discovery in healthcare, fraud detection in financial services, and molecular chemistry.
In this post, I introduce how to use cuGraph-DGL, a GPU-accelerated library for graph computations. It extends Deep Graph Library (DGL), a popular framework for GNNs that enables large-scale applications.
Basics of graph neural networks
Before I dive into cuGraph-DGL, I want to establish some basics. GNNs are a special kind of neural network designed to work with data structured as graphs. Unlike traditional neural networks that assume independence between samples, which doesn’t fit well with graph data, GNNs effectively exploit the rich and complex interconnections within graph data.
In a nutshell, GNNs work by propagating and transforming node features across the graph structure in multiple steps, often referred to as layers (Figure 1). Each layer updates the features of each node based on its own features and the features of its neighbors.
In Figure 1, the first step “prepares” a message composed of information from an edge and its connected nodes and then “passes” the message to the node. This process enables the model to learn high-level representations of nodes, edges, and the graph as a whole, which can be used for various downstream tasks like node classification, link prediction, and graph classification.
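In plain Python, one layer of this prepare-aggregate-update scheme can be sketched as follows; this is a conceptual illustration, not the DGL API.

def gnn_layer(node_feats, edges, message_fn, aggregate_fn, update_fn):
    """One conceptual message-passing layer.

    node_feats: dict mapping node id -> feature vector
    edges: iterable of (src, dst) pairs
    """
    # Prepare: every edge produces a message from its endpoints' features.
    inbox = {v: [] for v in node_feats}
    for src, dst in edges:
        inbox[dst].append(message_fn(node_feats[src], node_feats[dst]))
    # Aggregate and update: each node combines its messages with its own features.
    return {
        v: update_fn(node_feats[v], aggregate_fn(msgs)) if msgs else node_feats[v]
        for v, msgs in inbox.items()
    }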
Figure 2 shows how a 2-layer GNN is supposed to compute the output of node 5.
Bottlenecks when handling large-scale graphs
The bottleneck in GNN sampling and training is the lack of an existing implementation that can scale to handle billions or even trillions of edges, a scale often seen in real-world graph problems. For example, if you’re handling a graph with trillions of edges, you must be able to run DGL-based GNN workflows quickly.
One solution is to use RAPIDS, which already possesses the foundational elements capable of scaling to trillions of edges using GPUs.
What is RAPIDS cuGraph?
cuGraph is a part of the RAPIDS AI ecosystem, an open-source suite of software libraries for executing end-to-end data science and analytics pipelines entirely on GPUs. The cuGraph library provides a simple, flexible, and powerful API for graph analytics, enabling you to perform computations on graph data at scale and speed.
What is DGL?
Deep Graph Library (DGL) is a Python library designed to simplify the implementation of graph neural networks (GNNs) by providing intuitive interfaces and high-performance computation.
DGL supports a broad array of graph operations and structures, enhancing the modeling of complex systems and relationships. It also integrates with popular deep learning frameworks like PyTorch and TensorFlow, fostering seamless development and deployment of GNNs.
What is cuGraph-DGL?
cuGraph-DGL is an extension of cuGraph that integrates with the Deep Graph Library (DGL) to leverage the power of GPUs to run DGL-based GNN workflows at unprecedented speed. This library is a collaborative effort between DGL developers and cuGraph developers.
In addition to cuGraph-DGL, cuGraph also provides the cugraph-ops library, which enables DGL users to get performance boosts using CuGraphSAGEConv, CuGraphGATConv, and CuGraphRelGraphConv in place of the default SAGEConv, GATConv, and RelGraphConv models. You can also import the SAGEConv, GATConv, and RelGraphConv models directly from the cugraph_dgl library.
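As a quick illustration of the two import paths mentioned above, the following sketch shows both spellings. The exact module path inside cugraph_dgl (assumed here to be cugraph_dgl.nn) may vary between releases, so treat it as an assumption and check your installed version.

# cuGraph-ops convolutions exposed through DGL
from dgl.nn import CuGraphSAGEConv, CuGraphGATConv, CuGraphRelGraphConv

# The same models imported directly from the cugraph_dgl library
# (module path is an assumption; check your installed version)
from cugraph_dgl.nn import SAGEConv, GATConv, RelGraphConv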
In GNN sampling and training, the major challenge remains the absence of an implementation that can manage real-world graphs with billions or trillions of edges. cuGraph-DGL addresses this with its inherent capability to scale to trillions of edges using GPUs.
Setting up cuGraph-DGL
Before you dive into the code, make sure that you have cuGraph and DGL installed in your Python environment. The cugraph-dgl package is distributed through the standard RAPIDS channels; see the RAPIDS installation guide for the exact install command for your CUDA version.
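Once the environment is in place, a quick sanity check like the following can confirm that the stack is importable and a GPU is visible (a minimal sketch; it is not required for the rest of the post):

# Verify that the cuGraph-DGL stack is importable and a GPU is available
import torch
import dgl
import cugraph_dgl

print("CUDA available:", torch.cuda.is_available())
print("DGL version:", dgl.__version__)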
With your environment set up, put cuGraph-DGL into action and construct a simple GNN for node classification. Converting an existing DGL workflow to a cuGraph-DGL workflow has the following steps (a condensed sketch follows the list):
Use cuGraph-ops models such as CuGraphSAGEConv in place of the native DGL models (SAGEConv).
Create a CuGraphStorage object from the DGL graph.
Use the cuGraph-DGL data loader in place of the native DGL DataLoader.
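Here is that condensed sketch. Names such as dgl_g, train_nid, and temp_dir_name are placeholders defined in the full example later in this post; the calls themselves mirror the ones used there.

import cugraph_dgl
from dgl.nn import CuGraphSAGEConv as SAGEConv   # 1. cuGraph-ops model instead of dgl.nn.SAGEConv

# 2. Convert the DGL graph into a cuGraph-backed storage object
cugraph_g = cugraph_dgl.convert.cugraph_storage_from_heterograph(dgl_g, single_gpu=True)

# 3. Use the cuGraph-DGL sampler and data loader instead of the DGL ones
sampler = cugraph_dgl.dataloading.NeighborSampler([10, 10, 10])
dataloader = cugraph_dgl.dataloading.DataLoader(
    cugraph_g, train_nid, sampler,
    batch_size=128, shuffle=True, drop_last=False, num_workers=0,
    sampling_output_dir=temp_dir_name)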
Using cuGraph-DGL on a 3.2-billion-edge graph, we observed a 3x speedup when using eight GPUs for sampling and training, compared to a single-GPU UVA DGL setup. We also saw a 2x speedup when using eight GPUs for sampling and one GPU for training.
An upcoming blog post will provide more details on the gains and scalability.
Create a cuGraph-DGL graph
To create a cugraph_dgl graph directly from a DGL graph, run the following code example.
import dgl
import cugraph_dgl
dataset = dgl.data.CoraGraphDataset()
dgl_g = dataset[0]
# Add self loops as cugraph
# does not support isolated vertices yet
dgl_g = dgl.add_self_loop(dgl_g)
cugraph_g = cugraph_dgl.convert.cugraph_storage_from_heterograph(dgl_g, single_gpu=True)
For more information about creating a cuGraph storage object, see CuGraphStorage.
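A couple of quick checks on the converted object confirm that node data survived the conversion (a minimal sketch; the feature key "feat" and node type "_N" follow the Cora graph used above):

# Inspect the converted graph: node count and feature matrix shape
print(cugraph_g.num_nodes())
print(cugraph_g.ndata["feat"]["_N"].shape)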
Create a cuGraph-Ops-based model
In this step, the only modification to make is importing the cugraph-ops-based models. These models are drop-in replacements for upstream models like dgl.nn.SAGEConv.
# Drop-in replacement for dgl.nn.SAGEConv
from dgl.nn import CuGraphSAGEConv as SAGEConv
import torch.nn as nn
import torch.nn.functional as F

class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l_id, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l_id != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

# Create the model with given dimensions
feat_size = cugraph_g.ndata["feat"]["_N"].shape[1]
model = SAGE(feat_size, 256, dataset.num_classes).to("cuda")
Train the model
In this step, you opt to use cugraph_dgl.dataloading.NeighborSampler and cugraph_dgl.dataloading.DataLoader, replacing the conventional data loaders of upstream DGL.
import torchmetrics.functional as MF
import tempfile
import torch

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    features = g.ndata["feat"]["_N"].to("cuda")
    labels = g.ndata["label"]["_N"].to("cuda")
    train_nid = torch.tensor(range(g.num_nodes())).type(torch.int64)
    temp_dir_name = tempfile.TemporaryDirectory().name
    for epoch in range(10):
        model.train()
        sampler = cugraph_dgl.dataloading.NeighborSampler([10, 10, 10])
        dataloader = cugraph_dgl.dataloading.DataLoader(
            g, train_nid, sampler,
            batch_size=128,
            shuffle=True,
            drop_last=False,
            num_workers=0,
            sampling_output_dir=temp_dir_name)
        total_loss = 0
        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
            batch_inputs = features[input_nodes]
            batch_labels = labels[seeds]
            batch_pred = model(blocks, batch_inputs)
            loss = F.cross_entropy(batch_pred, batch_labels)
            total_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        sampler = cugraph_dgl.dataloading.NeighborSampler([-1, -1, -1])
        dataloader = cugraph_dgl.dataloading.DataLoader(
            g, train_nid, sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=0,
            sampling_output_dir=temp_dir_name)
        acc = evaluate(model, features, labels, dataloader)
        print("Epoch {:05d} | Acc {:.4f} | Loss {:.4f} ".format(epoch, acc, total_loss))
def evaluate(model, features, labels, dataloader):
    with torch.no_grad():
        model.eval()
        ys = []
        y_hats = []
        for it, (in_nodes, out_nodes, blocks) in enumerate(dataloader):
            with torch.no_grad():
                x = features[in_nodes]
                ys.append(labels[out_nodes])
                y_hats.append(model(blocks, x))
        num_classes = y_hats[0].shape[1]
        return MF.accuracy(
            torch.cat(y_hats),
            torch.cat(ys),
            task="multiclass",
            num_classes=num_classes,
        )
train(cugraph_g, model)
Epoch 00000 | Acc 0.3401 | Loss 39.3890
Epoch 00001 | Acc 0.7164 | Loss 27.8906
Epoch 00002 | Acc 0.7888 | Loss 16.9441
Epoch 00003 | Acc 0.8589 | Loss 12.5475
Epoch 00004 | Acc 0.8863 | Loss 9.9894
Epoch 00005 | Acc 0.8948 | Loss 9.0556
Epoch 00006 | Acc 0.9029 | Loss 7.3637
Epoch 00007 | Acc 0.9055 | Loss 7.2541
Epoch 00008 | Acc 0.9132 | Loss 6.6912
Epoch 00009 | Acc 0.9121 | Loss 7.0908
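As a follow-up, here is a minimal inference sketch that reuses the full-neighborhood sampler pattern from train to get per-node class predictions; the variable names all_nid and preds are illustrative.

import tempfile
import torch

# Full-neighborhood sampling over every node, mirroring the evaluation loader in train()
all_nid = torch.tensor(range(cugraph_g.num_nodes())).type(torch.int64)
sampler = cugraph_dgl.dataloading.NeighborSampler([-1, -1, -1])
dataloader = cugraph_dgl.dataloading.DataLoader(
    cugraph_g, all_nid, sampler,
    batch_size=1024, shuffle=False, drop_last=False, num_workers=0,
    sampling_output_dir=tempfile.TemporaryDirectory().name)

features = cugraph_g.ndata["feat"]["_N"].to("cuda")

model.eval()
preds = []
with torch.no_grad():
    for in_nodes, out_nodes, blocks in dataloader:
        logits = model(blocks, features[in_nodes])
        preds.append(logits.argmax(dim=1))
preds = torch.cat(preds)   # predicted class per node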
Conclusion
By combining the power of GPU-accelerated graph computations with the flexibility of DGL, cuGraph-DGL emerges as an invaluable tool for anyone dealing with graph data.
This post has only scratched the surface of what you can do with cuGraph-DGL. I encourage you to explore further, experiment with different GNN architectures, and discover how cuGraph-DGL can accelerate your graph-based, machine-learning tasks.