1package jme3tools.optimize;
2
import com.jme3.material.Material;
import com.jme3.math.Matrix4f;
import com.jme3.math.Transform;
import com.jme3.math.Vector3f;
import com.jme3.scene.Mesh.Mode;
import com.jme3.scene.*;
import com.jme3.scene.VertexBuffer.Format;
import com.jme3.scene.VertexBuffer.Type;
import com.jme3.scene.VertexBuffer.Usage;
import com.jme3.scene.mesh.IndexBuffer;
import com.jme3.util.BufferUtils;
import com.jme3.util.IntMap.Entry;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.ShortBuffer;
import java.util.*;
import java.util.logging.Logger;
20
21public class GeometryBatchFactory {
22
23    private static final Logger logger = Logger.getLogger(GeometryBatchFactory.class.getName());
24
25    private static void doTransformVerts(FloatBuffer inBuf, int offset, FloatBuffer outBuf, Matrix4f transform) {
26        Vector3f pos = new Vector3f();
27
28        // offset is given in element units
29        // convert to be in component units
30        offset *= 3;
31
32        for (int i = 0; i < inBuf.capacity() / 3; i++) {
33            pos.x = inBuf.get(i * 3 + 0);
34            pos.y = inBuf.get(i * 3 + 1);
35            pos.z = inBuf.get(i * 3 + 2);
36
37            transform.mult(pos, pos);
38
39            outBuf.put(offset + i * 3 + 0, pos.x);
40            outBuf.put(offset + i * 3 + 1, pos.y);
41            outBuf.put(offset + i * 3 + 2, pos.z);
42        }
43    }
44
45    private static void doTransformNorms(FloatBuffer inBuf, int offset, FloatBuffer outBuf, Matrix4f transform) {
46        Vector3f norm = new Vector3f();
47
48        // offset is given in element units
49        // convert to be in component units
50        offset *= 3;
51
52        for (int i = 0; i < inBuf.capacity() / 3; i++) {
53            norm.x = inBuf.get(i * 3 + 0);
54            norm.y = inBuf.get(i * 3 + 1);
55            norm.z = inBuf.get(i * 3 + 2);
56
57            transform.multNormal(norm, norm);
58
59            outBuf.put(offset + i * 3 + 0, norm.x);
60            outBuf.put(offset + i * 3 + 1, norm.y);
61            outBuf.put(offset + i * 3 + 2, norm.z);
62        }
63    }
64
65    private static void doTransformTangents(FloatBuffer inBuf, int offset, int components, FloatBuffer outBuf, Matrix4f transform) {
66        Vector3f tan = new Vector3f();
67
68        // offset is given in element units
69        // convert to be in component units
70        offset *= components;
71
72        for (int i = 0; i < inBuf.capacity() / components; i++) {
73            tan.x = inBuf.get(i * components + 0);
74            tan.y = inBuf.get(i * components + 1);
75            tan.z = inBuf.get(i * components + 2);
76
77            transform.multNormal(tan, tan);
78
79            outBuf.put(offset + i * components + 0, tan.x);
80            outBuf.put(offset + i * components + 1, tan.y);
81            outBuf.put(offset + i * components + 2, tan.z);
82
83            if (components == 4){
84                outBuf.put(offset + i * components + 3, inBuf.get(i * components + 3));
85            }
86        }
87    }
88
89    /**
90     * Merges all geometries in the collection into
91     * the output mesh. Creates a new material using the TextureAtlas.
92     *
93     * @param geometries
94     * @param outMesh
95     */
96    public static void mergeGeometries(Collection<Geometry> geometries, Mesh outMesh) {
97        int[] compsForBuf = new int[VertexBuffer.Type.values().length];
98        Format[] formatForBuf = new Format[compsForBuf.length];
99
100        int totalVerts = 0;
101        int totalTris = 0;
102        int totalLodLevels = 0;
103
104        Mode mode = null;
105        for (Geometry geom : geometries) {
106            totalVerts += geom.getVertexCount();
107            totalTris += geom.getTriangleCount();
108            totalLodLevels = Math.min(totalLodLevels, geom.getMesh().getNumLodLevels());
109
110            Mode listMode;
111            int components;
112            switch (geom.getMesh().getMode()) {
113                case Points:
114                    listMode = Mode.Points;
115                    components = 1;
116                    break;
117                case LineLoop:
118                case LineStrip:
119                case Lines:
120                    listMode = Mode.Lines;
121                    components = 2;
122                    break;
123                case TriangleFan:
124                case TriangleStrip:
125                case Triangles:
126                    listMode = Mode.Triangles;
127                    components = 3;
128                    break;
129                default:
130                    throw new UnsupportedOperationException();
131            }
132
133            for (VertexBuffer vb : geom.getMesh().getBufferList().getArray()){
134                compsForBuf[vb.getBufferType().ordinal()] = vb.getNumComponents();
135                formatForBuf[vb.getBufferType().ordinal()] = vb.getFormat();
136            }
137
138            if (mode != null && mode != listMode) {
139                throw new UnsupportedOperationException("Cannot combine different"
140                        + " primitive types: " + mode + " != " + listMode);
141            }
142            mode = listMode;
143            compsForBuf[Type.Index.ordinal()] = components;
144        }
145
146        outMesh.setMode(mode);
147        if (totalVerts >= 65536) {
148            // make sure we create an UnsignedInt buffer so
149            // we can fit all of the meshes
150            formatForBuf[Type.Index.ordinal()] = Format.UnsignedInt;
151        } else {
152            formatForBuf[Type.Index.ordinal()] = Format.UnsignedShort;
153        }
154
155        // generate output buffers based on retrieved info
156        for (int i = 0; i < compsForBuf.length; i++) {
157            if (compsForBuf[i] == 0) {
158                continue;
159            }
160
161            Buffer data;
162            if (i == Type.Index.ordinal()) {
163                data = VertexBuffer.createBuffer(formatForBuf[i], compsForBuf[i], totalTris);
164            } else {
165                data = VertexBuffer.createBuffer(formatForBuf[i], compsForBuf[i], totalVerts);
166            }
167
168            VertexBuffer vb = new VertexBuffer(Type.values()[i]);
169            vb.setupData(Usage.Static, compsForBuf[i], formatForBuf[i], data);
170            outMesh.setBuffer(vb);
171        }
172
173        int globalVertIndex = 0;
174        int globalTriIndex = 0;
175
176        for (Geometry geom : geometries) {
177            Mesh inMesh = geom.getMesh();
178            geom.computeWorldMatrix();
179            Matrix4f worldMatrix = geom.getWorldMatrix();
180
181            int geomVertCount = inMesh.getVertexCount();
182            int geomTriCount = inMesh.getTriangleCount();
183
184            for (int bufType = 0; bufType < compsForBuf.length; bufType++) {
185                VertexBuffer inBuf = inMesh.getBuffer(Type.values()[bufType]);
186                VertexBuffer outBuf = outMesh.getBuffer(Type.values()[bufType]);
187
188                if (inBuf == null || outBuf == null) {
189                    continue;
190                }
191
192                if (Type.Index.ordinal() == bufType) {
193                    int components = compsForBuf[bufType];
194
195                    IndexBuffer inIdx = inMesh.getIndicesAsList();
196                    IndexBuffer outIdx = outMesh.getIndexBuffer();
197
198                    for (int tri = 0; tri < geomTriCount; tri++) {
199                        for (int comp = 0; comp < components; comp++) {
200                            int idx = inIdx.get(tri * components + comp) + globalVertIndex;
201                            outIdx.put((globalTriIndex + tri) * components + comp, idx);
202                        }
203                    }
204                } else if (Type.Position.ordinal() == bufType) {
205                    FloatBuffer inPos = (FloatBuffer) inBuf.getDataReadOnly();
206                    FloatBuffer outPos = (FloatBuffer) outBuf.getData();
207                    doTransformVerts(inPos, globalVertIndex, outPos, worldMatrix);
208                } else if (Type.Normal.ordinal() == bufType) {
209                    FloatBuffer inPos = (FloatBuffer) inBuf.getDataReadOnly();
210                    FloatBuffer outPos = (FloatBuffer) outBuf.getData();
211                    doTransformNorms(inPos, globalVertIndex, outPos, worldMatrix);
212                }else if(Type.Tangent.ordinal() == bufType){
213                    FloatBuffer inPos = (FloatBuffer) inBuf.getDataReadOnly();
214                    FloatBuffer outPos = (FloatBuffer) outBuf.getData();
215                    int components = inBuf.getNumComponents();
216                    doTransformTangents(inPos, globalVertIndex, components, outPos, worldMatrix);
217                } else {
218                    inBuf.copyElements(0, outBuf, globalVertIndex, geomVertCount);
219                }
220            }
221
222            globalVertIndex += geomVertCount;
223            globalTriIndex += geomTriCount;
224        }
225    }
226
227    public static void makeLods(Collection<Geometry> geometries, Mesh outMesh) {
228        int lodLevels = 0;
229        int[] lodSize = null;
230        int index = 0;
231        for (Geometry g : geometries) {
232            if (lodLevels == 0) {
233                lodLevels = g.getMesh().getNumLodLevels();
234            }
235            if (lodSize == null) {
236                lodSize = new int[lodLevels];
237            }
238            for (int i = 0; i < lodLevels; i++) {
239                lodSize[i] += g.getMesh().getLodLevel(i).getData().capacity();
240                //if( i == 0) System.out.println(index + " " +lodSize[i]);
241            }
242            index++;
243        }
244        int[][] lodData = new int[lodLevels][];
245        for (int i = 0; i < lodLevels; i++) {
246            lodData[i] = new int[lodSize[i]];
247        }
248        VertexBuffer[] lods = new VertexBuffer[lodLevels];
249        int bufferPos[] = new int[lodLevels];
250        //int index = 0;
251        int numOfVertices = 0;
252        int curGeom = 0;
253        for (Geometry g : geometries) {
254            if (numOfVertices == 0) {
255                numOfVertices = g.getVertexCount();
256            }
257            for (int i = 0; i < lodLevels; i++) {
258                ShortBuffer buffer = (ShortBuffer) g.getMesh().getLodLevel(i).getDataReadOnly();
259                //System.out.println("buffer: " + buffer.capacity() + " limit: " + lodSize[i] + " " + index);
260                for (int j = 0; j < buffer.capacity(); j++) {
261                    lodData[i][bufferPos[i] + j] = buffer.get() + numOfVertices * curGeom;
262                    //bufferPos[i]++;
263                }
264                bufferPos[i] += buffer.capacity();
265            }
266            curGeom++;
267        }
268        for (int i = 0; i < lodLevels; i++) {
269            lods[i] = new VertexBuffer(Type.Index);
270            lods[i].setupData(Usage.Dynamic, 1, Format.UnsignedInt, BufferUtils.createIntBuffer(lodData[i]));
271        }
272        System.out.println(lods.length);
273        outMesh.setLodLevels(lods);
274    }
275
276    public static List<Geometry> makeBatches(Collection<Geometry> geometries) {
277        return makeBatches(geometries, false);
278    }
279
280    /**
281     * Batches a collection of Geometries so that all with the same material get combined.
282     * @param geometries The Geometries to combine
283     * @return A List of newly created Geometries, each with a  distinct material
284     */
285    public static List<Geometry> makeBatches(Collection<Geometry> geometries, boolean useLods) {
286        ArrayList<Geometry> retVal = new ArrayList<Geometry>();
287        HashMap<Material, List<Geometry>> matToGeom = new HashMap<Material, List<Geometry>>();
288
289        for (Geometry geom : geometries) {
290            List<Geometry> outList = matToGeom.get(geom.getMaterial());
291            if (outList == null) {
292                outList = new ArrayList<Geometry>();
293                matToGeom.put(geom.getMaterial(), outList);
294            }
295            outList.add(geom);
296        }
297
298        int batchNum = 0;
299        for (Map.Entry<Material, List<Geometry>> entry : matToGeom.entrySet()) {
300            Material mat = entry.getKey();
301            List<Geometry> geomsForMat = entry.getValue();
302            Mesh mesh = new Mesh();
303            mergeGeometries(geomsForMat, mesh);
304            // lods
305            if (useLods) {
306                makeLods(geomsForMat, mesh);
307            }
308            mesh.updateCounts();
309            mesh.updateBound();
310
311            Geometry out = new Geometry("batch[" + (batchNum++) + "]", mesh);
312            out.setMaterial(mat);
313            retVal.add(out);
314        }
315
316        return retVal;
317    }
318
319    public static void gatherGeoms(Spatial scene, List<Geometry> geoms) {
320        if (scene instanceof Node) {
321            Node node = (Node) scene;
322            for (Spatial child : node.getChildren()) {
323                gatherGeoms(child, geoms);
324            }
325        } else if (scene instanceof Geometry) {
326            geoms.add((Geometry) scene);
327        }
328    }
329
330    /**
331     * Optimizes a scene by combining Geometry with the same material.
332     * All Geometries found in the scene are detached from their parent and
333     * a new Node containing the optimized Geometries is attached.
334     * @param scene The scene to optimize
335     * @return The newly created optimized geometries attached to a node
336     */
337    public static Spatial optimize(Node scene) {
338        return optimize(scene, false);
339    }
340
341    /**
342     * Optimizes a scene by combining Geometry with the same material.
343     * All Geometries found in the scene are detached from their parent and
344     * a new Node containing the optimized Geometries is attached.
345     * @param scene The scene to optimize
346     * @param useLods true if you want the resulting geometry to keep lod information
347     * @return The newly created optimized geometries attached to a node
348     */
349    public static Node optimize(Node scene, boolean useLods) {
350        ArrayList<Geometry> geoms = new ArrayList<Geometry>();
351
352        gatherGeoms(scene, geoms);
353
354        List<Geometry> batchedGeoms = makeBatches(geoms, useLods);
355        for (Geometry geom : batchedGeoms) {
356            scene.attachChild(geom);
357        }
358
359        for (Iterator<Geometry> it = geoms.iterator(); it.hasNext();) {
360            Geometry geometry = it.next();
361            geometry.removeFromParent();
362        }
363
364        // Since the scene is returned unaltered the transform must be reset
365        scene.setLocalTransform(Transform.IDENTITY);
366
367        return scene;
368    }
369
370    public static void printMesh(Mesh mesh) {
371        for (int bufType = 0; bufType < Type.values().length; bufType++) {
372            VertexBuffer outBuf = mesh.getBuffer(Type.values()[bufType]);
373            if (outBuf == null) {
374                continue;
375            }
376
377            System.out.println(outBuf.getBufferType() + ": ");
378            for (int vert = 0; vert < outBuf.getNumElements(); vert++) {
379                String str = "[";
380                for (int comp = 0; comp < outBuf.getNumComponents(); comp++) {
381                    Object val = outBuf.getElementComponent(vert, comp);
382                    outBuf.setElementComponent(vert, comp, val);
383                    val = outBuf.getElementComponent(vert, comp);
384                    str += val;
385                    if (comp != outBuf.getNumComponents() - 1) {
386                        str += ", ";
387                    }
388                }
389                str += "]";
390                System.out.println(str);
391            }
392            System.out.println("------");
393        }
394    }
395
396    public static void main(String[] args) {
397        Mesh mesh = new Mesh();
398        mesh.setBuffer(Type.Position, 3, new float[]{
399                    0, 0, 0,
400                    1, 0, 0,
401                    1, 1, 0,
402                    0, 1, 0
403                });
404        mesh.setBuffer(Type.Index, 2, new short[]{
405                    0, 1,
406                    1, 2,
407                    2, 3,
408                    3, 0
409                });
410
411        Geometry g1 = new Geometry("g1", mesh);
412
413        ArrayList<Geometry> geoms = new ArrayList<Geometry>();
414        geoms.add(g1);
415
416        Mesh outMesh = new Mesh();
417        mergeGeometries(geoms, outMesh);
418        printMesh(outMesh);
419    }
420}
421